#[cfg(feature = "cuda")]
use cudarc::driver::LaunchConfig;
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
#[cfg(feature = "cuda")]
use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64, cpu_to_gpu, gpu_to_cpu};
#[cfg(feature = "cuda")]
/// Translates an f32 PTX kernel into its f64 counterpart by textual
/// substitution: entry-point name, type suffixes on loads/stores/arithmetic,
/// the `%off` byte-offset shift (stride 4 -> 8), and f32 hex immediates
/// (`0f…` singles -> the equivalent `0d…` doubles).
///
/// Substitution order matters — more specific patterns (`div.rn.f32`,
/// `sqrt.rn.f32`) must run before their shorter prefixes (`div.f32`,
/// `sqrt.f32`) — so the table below preserves the required ordering.
pub(crate) fn ptx_f32_to_f64(f32_ptx: &str, f32_kernel_name: &str, f64_kernel_name: &str) -> String {
    // (f32 pattern, f64 replacement) pairs, applied strictly in order.
    const REWRITES: &[(&str, &str)] = &[
        (".reg .f32", ".reg .f64"),
        ("ld.global.f32", "ld.global.f64"),
        ("st.global.f32", "st.global.f64"),
        ("ld.shared.f32", "ld.shared.f64"),
        ("st.shared.f32", "st.shared.f64"),
        ("ld.param.f32", "ld.param.f64"),
        (".param .f32", ".param .f64"),
        (".shared .align 4 .f32", ".shared .align 8 .f64"),
        ("add.f32", "add.f64"),
        ("sub.f32", "sub.f64"),
        ("mul.f32", "mul.f64"),
        ("div.rn.f32", "div.rn.f64"),
        ("div.f32", "div.f64"),
        ("neg.f32", "neg.f64"),
        ("abs.f32", "abs.f64"),
        ("max.f32", "max.f64"),
        ("min.f32", "min.f64"),
        ("sqrt.rn.f32", "sqrt.rn.f64"),
        ("sqrt.f32", "sqrt.f64"),
        ("fma.rn.f32", "fma.rn.f64"),
        ("mov.f32", "mov.f64"),
        ("setp.gt.f32", "setp.gt.f64"),
        ("setp.ge.f32", "setp.ge.f64"),
        ("setp.lt.f32", "setp.lt.f64"),
        ("setp.le.f32", "setp.le.f64"),
        ("setp.eq.f32", "setp.eq.f64"),
        ("setp.ne.f32", "setp.ne.f64"),
        ("cvt.rn.f32.u32", "cvt.rn.f64.u32"),
        ("cvt.rn.f32.s32", "cvt.rn.f64.s32"),
        ("mov.b32", "mov.b64"),
        // Element stride: 4-byte f32 -> 8-byte f64 (shift amount 2 -> 3).
        ("shl.b64 %off, %off, 2", "shl.b64 %off, %off, 3"),
        ("atom.global.add.f32", "atom.global.add.f64"),
        // Hex immediates: 0.0, 1.0, -1.0, 2.0, 0.5, -inf, +inf, log2(e), ln(2).
        ("0f00000000", "0d0000000000000000"),
        ("0f3F800000", "0d3FF0000000000000"),
        ("0fBF800000", "0dBFF0000000000000"),
        ("0f40000000", "0d4000000000000000"),
        ("0f3F000000", "0d3FE0000000000000"),
        ("0fFF800000", "0dFFF0000000000000"),
        ("0f7F800000", "0d7FF0000000000000"),
        ("0f3FB8AA3B", "0d3FF71547652B82FE"),
        ("0f3F317218", "0d3FE62E42FEFA39EF"),
    ];
    // Rename the entry point first so the remaining rewrites see the same
    // text the original chained-`replace` version produced.
    let renamed = f32_ptx.replace(f32_kernel_name, f64_kernel_name);
    REWRITES
        .iter()
        .fold(renamed, |ptx, (from, to)| ptx.replace(from, to))
}
#[cfg(feature = "cuda")]
/// Returns the lazily-built f64 translation of `f32_ptx`, memoized in
/// `cache` so the textual rewrite runs at most once per kernel.
pub(crate) fn get_f64_ptx<'a>(
    cache: &'a std::sync::OnceLock<String>,
    f32_ptx: &str,
    f32_name: &str,
    f64_name: &str,
) -> &'a str {
    // First caller pays for the translation; everyone else gets the cached
    // string. `OnceLock` guarantees the closure runs at most once.
    let ptx: &'a String = cache.get_or_init(|| ptx_f32_to_f64(f32_ptx, f32_name, f64_name));
    ptx.as_str()
}
#[cfg(feature = "cuda")]
/// PTX for `add_kernel`: element-wise f32 addition, `out[i] = a[i] + b[i]`.
/// One thread per element; threads with global id >= `n` branch to `DONE`,
/// so over-provisioned launches are safe.
pub(crate) const ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry add_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
add.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `add_vec4_kernel`: f32 addition using `v4.f32` vector loads and
/// stores, 4 elements per thread. `n4` is the element count divided by 4.
/// The 16-byte vector accesses require 16-byte-aligned pointers (a PTX
/// requirement) — presumably guaranteed by the launcher; TODO confirm.
pub(crate) const ADD_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry add_vec4_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n4
) {
.reg .u32 %r_tid, %bid, %bdim, %n4_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n4_reg, [n4];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n4_reg;
@%p bra DONE;
// Byte offset = tid * 16 (4 floats × 4 bytes)
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 4;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];
add.f32 %r0, %a0, %b0;
add.f32 %r1, %a1, %b1;
add.f32 %r2, %a2, %b2;
add.f32 %r3, %a3, %b3;
st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `mul_vec4_kernel`: element-wise f32 multiply with `v4.f32`
/// vector accesses, 4 elements per thread; `n4` = element count / 4.
/// Same 16-byte alignment requirement as `ADD_VEC4_PTX`.
pub(crate) const MUL_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mul_vec4_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n4
) {
.reg .u32 %r_tid, %bid, %bdim, %n4_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n4_reg, [n4];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n4_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 4;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];
mul.f32 %r0, %a0, %b0;
mul.f32 %r1, %a1, %b1;
mul.f32 %r2, %a2, %b2;
mul.f32 %r3, %a3, %b3;
st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `sub_kernel`: element-wise f32 subtraction, `out[i] = a[i] - b[i]`.
/// One thread per element with an `n` bounds check.
pub(crate) const SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sub_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
sub.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `mul_kernel`: element-wise f32 multiply, `out[i] = a[i] * b[i]`.
/// One thread per element with an `n` bounds check.
pub(crate) const MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mul_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
mul.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `neg_kernel`: element-wise f32 negation, `out[i] = -a[i]`,
/// via `neg.f32`. One thread per element with an `n` bounds check.
pub(crate) const NEG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry neg_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
neg.f32 %vr, %va;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `relu_kernel`: element-wise ReLU, `out[i] = max(a[i], 0.0)`,
/// via `max.f32` against a zero register. One thread per element.
pub(crate) const RELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry relu_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %zero;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
mov.f32 %zero, 0f00000000;
max.f32 %vr, %va, %zero;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `scale_kernel`: multiplies every element by a scalar passed as an
/// f32 kernel parameter, `out[i] = a[i] * scalar`. One thread per element.
pub(crate) const SCALE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry scale_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .f32 scalar,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %s;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.f32 %s, [scalar];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
mul.f32 %vr, %va, %s;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `transpose_2d_kernel`: transposes a row-major `[M, N]` f32 matrix
/// into a `[N, M]` output. One thread per *output* element: the linear tid is
/// split as `out_row * M + out_col` and the value is gathered from
/// `in[out_col * N + out_row]`, so output writes are coalesced.
pub(crate) const TRANSPOSE_2D_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry transpose_2d_kernel(\n\
.param .u64 in_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 M,\n\
.param .u32 N,\n\
.param .u32 total\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
.reg .u32 %out_row, %out_col, %in_idx;\n\
.reg .u64 %in, %out, %off_in, %off_out;\n\
.reg .f32 %val;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %in, [in_ptr];\n\
ld.param.u64 %out, [out_ptr];\n\
ld.param.u32 %M_reg, [M];\n\
ld.param.u32 %N_reg, [N];\n\
ld.param.u32 %total_reg, [total];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %total_reg;\n\
@%p bra DONE;\n\
\n\
// Output shape is [N, M]. tid = out_row * M + out_col.\n\
div.u32 %out_row, %r_tid, %M_reg;\n\
rem.u32 %out_col, %r_tid, %M_reg;\n\
// Input index: out_col * N + out_row (transposed).\n\
mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
\n\
cvt.u64.u32 %off_in, %in_idx;\n\
shl.b64 %off_in, %off_in, 2;\n\
add.u64 %off_in, %in, %off_in;\n\
ld.global.f32 %val, [%off_in];\n\
\n\
cvt.u64.u32 %off_out, %r_tid;\n\
shl.b64 %off_out, %off_out, 2;\n\
add.u64 %off_out, %out, %off_out;\n\
st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
/// PTX for `permute_0213_kernel`: permutes a contiguous `[d0, d1, d2, d3]`
/// f32 tensor into `[d0, d2, d1, d3]` layout (axes 1 and 2 swapped). One
/// thread per output element: the linear output index is decomposed into
/// `(i0, i2, i1, i3)` and mapped back to the input's row-major index.
pub(crate) const PERMUTE_0213_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry permute_0213_kernel(\n\
.param .u64 in_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 d0,\n\
.param .u32 d1,\n\
.param .u32 d2,\n\
.param .u32 d3,\n\
.param .u32 total\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
.reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
.reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
.reg .u32 %s_out2, %s_out1, %s_in1;\n\
.reg .u64 %in, %out, %off_in, %off_out;\n\
.reg .f32 %val;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %in, [in_ptr];\n\
ld.param.u64 %out, [out_ptr];\n\
ld.param.u32 %d0r, [d0];\n\
ld.param.u32 %d1r, [d1];\n\
ld.param.u32 %d2r, [d2];\n\
ld.param.u32 %d3r, [d3];\n\
ld.param.u32 %total_reg, [total];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %total_reg;\n\
@%p bra DONE;\n\
\n\
// Output shape: [d0, d2, d1, d3]\n\
// Decompose tid into (i0, i2, i1, i3) in output layout.\n\
mul.lo.u32 %s_out2, %d1r, %d3r;\n\
mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
\n\
div.u32 %i0, %r_tid, %s_out1;\n\
rem.u32 %rem, %r_tid, %s_out1;\n\
div.u32 %i2, %rem, %s_out2;\n\
rem.u32 %rem, %rem, %s_out2;\n\
div.u32 %i1, %rem, %d3r;\n\
rem.u32 %i3, %rem, %d3r;\n\
\n\
// Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
mul.lo.u32 %s_in1, %d2r, %d3r;\n\
mul.lo.u32 %in_idx, %i0, %d1r;\n\
add.u32 %in_idx, %in_idx, %i1;\n\
mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
add.u32 %in_idx, %in_idx, %i3;\n\
\n\
cvt.u64.u32 %off_in, %in_idx;\n\
shl.b64 %off_in, %off_in, 2;\n\
add.u64 %off_in, %in, %off_in;\n\
ld.global.f32 %val, [%off_in];\n\
\n\
cvt.u64.u32 %off_out, %r_tid;\n\
shl.b64 %off_out, %off_out, 2;\n\
add.u64 %off_out, %out, %off_out;\n\
st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
/// PTX for `f32_to_f16_kernel`: narrows each f32 to IEEE half precision with
/// `cvt.rn.f16.f32` (round-to-nearest-even). Reads 4-byte elements at `i*4`
/// and writes 2-byte elements at `i*2`.
pub(crate) const F32_TO_F16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry f32_to_f16_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off_in, %off_out;
.reg .f32 %vf;
.reg .b16 %vh;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// Compute input offset: i * 4 (f32 = 4 bytes)
cvt.u64.u32 %off_in, %r_tid;
shl.b64 %off_in, %off_in, 2;
add.u64 %in, %in, %off_in;
// Compute output offset: i * 2 (f16 = 2 bytes)
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 1;
add.u64 %out, %out, %off_out;
// Load f32, convert to f16 (round-to-nearest-even), store as u16
ld.global.f32 %vf, [%in];
cvt.rn.f16.f32 %vh, %vf;
st.global.b16 [%out], %vh;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `f32_to_bf16_kernel`: narrows each f32 to bfloat16 by integer
/// round-to-nearest-even on the raw bits (add `0x7FFF` plus the kept LSB,
/// then shift right 16).
/// NOTE(review): the rounding carry can propagate a signalling-NaN payload
/// into the exponent and produce an infinity bit pattern — confirm NaN
/// inputs are not expected on this path.
pub(crate) const F32_TO_BF16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry f32_to_bf16_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off_in, %off_out;
.reg .f32 %vf;
.reg .u32 %bits, %round, %lsb, %result;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off_in, %r_tid;
shl.b64 %off_in, %off_in, 2;
add.u64 %in, %in, %off_in;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 1;
add.u64 %out, %out, %off_out;
// Load f32 as raw bits
ld.global.u32 %bits, [%in];
// Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
shr.u32 %lsb, %bits, 16;
and.b32 %lsb, %lsb, 1;
add.u32 %round, %bits, 0x7FFF;
add.u32 %round, %round, %lsb;
shr.u32 %result, %round, 16;
// Store as u16
st.global.u16 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `small_matmul_kernel`: naive `[M,K] x [K,N] -> [M,N]` f32 matmul.
/// One thread per output element runs a scalar K-loop accumulating with
/// `fma.rn.f32`. No tiling or shared memory, so (as the name suggests) it is
/// only sensible for small matrices.
pub(crate) const SMALL_MATMUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry small_matmul_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 c_ptr,
.param .u32 M,
.param .u32 K,
.param .u32 N,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
.reg .u32 %row, %col, %p, %idx;
.reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
.reg .f32 %sum, %va, %vb;
.reg .pred %bounds_p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %c, [c_ptr];
ld.param.u32 %M_reg, [M];
ld.param.u32 %K_reg, [K];
ld.param.u32 %N_reg, [N];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %bounds_p, %r_tid, %total_reg;
@%bounds_p bra DONE;
div.u32 %row, %r_tid, %N_reg;
rem.u32 %col, %r_tid, %N_reg;
mov.f32 %sum, 0f00000000;
mov.u32 %p, 0;
DOT:
setp.ge.u32 %loop_p, %p, %K_reg;
@%loop_p bra DOT_DONE;
mad.lo.u32 %idx, %row, %K_reg, %p;
cvt.u64.u32 %a_off, %idx;
shl.b64 %a_off, %a_off, 2;
add.u64 %a_off, %a, %a_off;
ld.global.f32 %va, [%a_off];
mad.lo.u32 %idx, %p, %N_reg, %col;
cvt.u64.u32 %b_off, %idx;
shl.b64 %b_off, %b_off, 2;
add.u64 %b_off, %b, %b_off;
ld.global.f32 %vb, [%b_off];
fma.rn.f32 %sum, %va, %vb, %sum;
add.u32 %p, %p, 1;
bra DOT;
DOT_DONE:
cvt.u64.u32 %c_off, %r_tid;
shl.b64 %c_off, %c_off, 2;
add.u64 %c_off, %c, %c_off;
st.global.f32 [%c_off], %sum;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `slice_write_kernel`: copies an `n`-element f32 source, viewed as
/// `[n/D, D]`, into row `pos` of each batch of a `[batch, max_len, D]`
/// destination: `dst[(batch * max_len + pos) * D + d] = src[batch * D + d]`.
/// `pos` is a host-side kernel parameter.
pub(crate) const SLICE_WRITE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry slice_write_kernel(
.param .u64 src_ptr,
.param .u64 dst_ptr,
.param .u32 n,
.param .u32 D,
.param .u32 max_len,
.param .u32 pos
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
.reg .u32 %batch_idx, %d_idx, %dst_row;
.reg .u64 %src, %dst, %src_off, %dst_off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %src, [src_ptr];
ld.param.u64 %dst, [dst_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %D_reg, [D];
ld.param.u32 %max_len_reg, [max_len];
ld.param.u32 %pos_reg, [pos];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %src_off, %r_tid;
shl.b64 %src_off, %src_off, 2;
add.u64 %src, %src, %src_off;
ld.global.f32 %val, [%src];
div.u32 %batch_idx, %r_tid, %D_reg;
rem.u32 %d_idx, %r_tid, %D_reg;
mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
add.u32 %dst_row, %dst_row, %pos_reg;
mul.lo.u32 %dst_row, %dst_row, %D_reg;
add.u32 %dst_row, %dst_row, %d_idx;
cvt.u64.u32 %dst_off, %dst_row;
shl.b64 %dst_off, %dst_off, 2;
add.u64 %dst, %dst, %dst_off;
st.global.f32 [%dst], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `slice_write_indirect_kernel`: identical copy pattern to
/// `slice_write_kernel`, except the destination row index is read from
/// device memory (`*pos_ptr` as u32) instead of a host-side parameter, so
/// the position can be produced by a previous kernel without a sync.
pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry slice_write_indirect_kernel(
.param .u64 src_ptr,
.param .u64 dst_ptr,
.param .u32 n,
.param .u32 D,
.param .u32 max_len,
.param .u64 pos_ptr
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
.reg .u32 %batch_idx, %d_idx, %dst_row;
.reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %src, [src_ptr];
ld.param.u64 %dst, [dst_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %D_reg, [D];
ld.param.u32 %max_len_reg, [max_len];
ld.param.u64 %pos_p, [pos_ptr];
ld.global.u32 %pos_reg, [%pos_p];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %src_off, %r_tid;
shl.b64 %src_off, %src_off, 2;
add.u64 %src, %src, %src_off;
ld.global.f32 %val, [%src];
div.u32 %batch_idx, %r_tid, %D_reg;
rem.u32 %d_idx, %r_tid, %D_reg;
mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
add.u32 %dst_row, %dst_row, %pos_reg;
mul.lo.u32 %dst_row, %dst_row, %D_reg;
add.u32 %dst_row, %dst_row, %d_idx;
cvt.u64.u32 %dst_off, %dst_row;
shl.b64 %dst_off, %dst_off, 2;
add.u64 %dst, %dst, %dst_off;
st.global.f32 [%dst], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `causal_mask_indirect_kernel`: fills a `[total / max_pos,
/// max_pos]` f32 buffer with 0.0 where `col < *total_len_ptr` and -1.0e9
/// otherwise. The valid length is read from device memory, so it can be
/// produced on-GPU without a host round-trip.
pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry causal_mask_indirect_kernel(
.param .u64 total_len_ptr,
.param .u64 out_ptr,
.param .u32 max_pos,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
.reg .u64 %out, %off, %tl_p;
.reg .f32 %val;
.reg .pred %bounds_p, %mask_p;
ld.param.u64 %tl_p, [total_len_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %max_pos_reg, [max_pos];
ld.param.u32 %total_reg, [total];
ld.global.u32 %tlen, [%tl_p];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %bounds_p, %r_tid, %total_reg;
@%bounds_p bra DONE;
rem.u32 %col, %r_tid, %max_pos_reg;
setp.lt.u32 %mask_p, %col, %tlen;
@%mask_p bra WRITE_ZERO;
// 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
// to effectively zero out masked positions after softmax.
mov.f32 %val, 0fCE6E6B28;
bra WRITE;
WRITE_ZERO:
mov.f32 %val, 0f00000000;
WRITE:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %out, %out, %off;
st.global.f32 [%out], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `embed_lookup_kernel`: copies one embedding row
/// `weight[row*D .. row*D+D]` into `out[0..D]`. The row index is stored as a
/// single f32 at `idx_ptr` and truncated to u32 (`cvt.rzi.u32.f32`), so
/// indices must be non-negative and exactly representable in f32.
pub(crate) const EMBED_LOOKUP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry embed_lookup_kernel(
.param .u64 idx_ptr,
.param .u64 weight_ptr,
.param .u64 out_ptr,
.param .u32 D
) {
.reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
.reg .u64 %idx_addr, %w, %out, %off;
.reg .f32 %idx_f, %val;
.reg .pred %p;
ld.param.u64 %idx_addr, [idx_ptr];
ld.param.u64 %w, [weight_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %D_reg, [D];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %D_reg;
@%p bra DONE;
ld.global.f32 %idx_f, [%idx_addr];
cvt.rzi.u32.f32 %row, %idx_f;
mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
cvt.u64.u32 %off, %src_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %w, %off;
ld.global.f32 %val, [%off];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `embed_lookup_batch_kernel`: batched embedding gather. With
/// `total = rows * D`, each thread writes one element of the `[rows, D]`
/// output: `out[row*D + col] = weight[indices[row]*D + col]`, where
/// `indices` holds f32-encoded row ids truncated to u32.
pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry embed_lookup_batch_kernel(
.param .u64 idx_ptr,
.param .u64 weight_ptr,
.param .u64 out_ptr,
.param .u32 D,
.param .u32 total
) {
.reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
.reg .u32 %row, %col, %src_idx;
.reg .u64 %idx_addr, %w, %out, %off;
.reg .f32 %idx_f, %val;
.reg .pred %p;
ld.param.u64 %idx_addr, [idx_ptr];
ld.param.u64 %w, [weight_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %D_reg, [D];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tid, %tid.x;
mad.lo.u32 %tid, %bid, %bdim, %tid;
setp.ge.u32 %p, %tid, %total_reg;
@%p bra DONE;
// row = tid / D, col = tid % D
div.u32 %row, %tid, %D_reg;
rem.u32 %col, %tid, %D_reg;
// Read indices[row] (f32 -> u32)
cvt.u64.u32 %off, %row;
shl.b64 %off, %off, 2;
add.u64 %off, %idx_addr, %off;
ld.global.f32 %idx_f, [%off];
cvt.rzi.u32.f32 %src_idx, %idx_f;
// src_idx = indices[row] * D + col
mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
cvt.u64.u32 %off, %src_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %w, %off;
ld.global.f32 %val, [%off];
// Write to out[tid]
cvt.u64.u32 %off, %tid;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `scatter_add_rows_kernel`: embedding-gradient scatter-add.
/// For each element of the `[rows, D]` gradient (`total = rows * D`):
/// `grad_weight[indices[row]*D + col] += grad_output[row*D + col]`, using
/// `atom.global.add.f32` so duplicate indices accumulate correctly.
/// Indices are f32-encoded and truncated to u32.
pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry scatter_add_rows_kernel(
.param .u64 grad_output_ptr,
.param .u64 indices_ptr,
.param .u64 grad_weight_ptr,
.param .u32 D,
.param .u32 total
) {
.reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
.reg .u32 %row, %col, %dst_idx;
.reg .u64 %go, %idx_addr, %gw, %off;
.reg .f32 %idx_f, %grad_val, %dummy;
.reg .pred %p;
ld.param.u64 %go, [grad_output_ptr];
ld.param.u64 %idx_addr, [indices_ptr];
ld.param.u64 %gw, [grad_weight_ptr];
ld.param.u32 %D_reg, [D];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tid, %tid.x;
mad.lo.u32 %tid, %bid, %bdim, %tid;
setp.ge.u32 %p, %tid, %total_reg;
@%p bra DONE;
// row = tid / D, col = tid % D
div.u32 %row, %tid, %D_reg;
rem.u32 %col, %tid, %D_reg;
// Read grad_output[tid]
cvt.u64.u32 %off, %tid;
shl.b64 %off, %off, 2;
add.u64 %off, %go, %off;
ld.global.f32 %grad_val, [%off];
// Read indices[row] (f32 -> u32)
cvt.u64.u32 %off, %row;
shl.b64 %off, %off, 2;
add.u64 %off, %idx_addr, %off;
ld.global.f32 %idx_f, [%off];
cvt.rzi.u32.f32 %dst_idx, %idx_f;
// dst_idx = indices[row] * D + col
mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
cvt.u64.u32 %off, %dst_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %gw, %off;
atom.global.add.f32 %dummy, [%off], %grad_val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `slice_read_kernel`: inverse of `slice_write_kernel`. Gathers the
/// first `len` rows of each `[max_len, D]` batch of `src` into a packed
/// `[batch, len, D]` destination (`total = batch * len * D`).
pub(crate) const SLICE_READ_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry slice_read_kernel(
.param .u64 src_ptr,
.param .u64 dst_ptr,
.param .u32 total,
.param .u32 D,
.param .u32 len,
.param .u32 max_len
) {
.reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
.reg .u32 %batch_idx, %within, %row, %col, %src_idx;
.reg .u32 %len_d;
.reg .u64 %src, %dst, %src_off, %dst_off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %src, [src_ptr];
ld.param.u64 %dst, [dst_ptr];
ld.param.u32 %total_reg, [total];
ld.param.u32 %D_reg, [D];
ld.param.u32 %len_reg, [len];
ld.param.u32 %max_len_reg, [max_len];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %total_reg;
@%p bra DONE;
// dst index = r_tid
// batch_idx = r_tid / (len * D)
// within = r_tid % (len * D)
// row = within / D
// col = within % D
// src_idx = batch_idx * max_len * D + row * D + col
mul.lo.u32 %len_d, %len_reg, %D_reg;
div.u32 %batch_idx, %r_tid, %len_d;
rem.u32 %within, %r_tid, %len_d;
div.u32 %row, %within, %D_reg;
rem.u32 %col, %within, %D_reg;
mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
add.u32 %src_idx, %src_idx, %row;
mul.lo.u32 %src_idx, %src_idx, %D_reg;
add.u32 %src_idx, %src_idx, %col;
cvt.u64.u32 %src_off, %src_idx;
shl.b64 %src_off, %src_off, 2;
add.u64 %src_off, %src, %src_off;
ld.global.f32 %val, [%src_off];
cvt.u64.u32 %dst_off, %r_tid;
shl.b64 %dst_off, %dst_off, 2;
add.u64 %dst_off, %dst, %dst_off;
st.global.f32 [%dst_off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `gelu_kernel`: sigmoid-approximated GELU,
/// `out = x * sigmoid(k * x)` with k = 0f3FDA2720 (~1.70432). The
/// exponential uses `ex2.approx` (exp(y) = 2^(y * log2 e)) and the division
/// uses `rcp.approx`, so results carry approximate-instruction error.
pub(crate) const GELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%in];
mov.f32 %k, 0f3FDA2720;
mul.f32 %neg_kx, %k, %x;
neg.f32 %neg_kx, %neg_kx;
mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
ex2.approx.f32 %exp_neg, %neg_kx;
mov.f32 %one, 0f3F800000;
add.f32 %denom, %one, %exp_neg;
rcp.approx.f32 %sig, %denom;
mul.f32 %result, %x, %sig;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `gelu_f64_kernel`: double-precision sigmoid-approximated GELU,
/// `out = x / (1 + exp(-k*x))`, with k matching the f32 kernel's 0f3FDA2720.
/// exp() is computed by Cody–Waite range reduction (arg = n*ln2 + r)
/// followed by a degree-12 Horner polynomial in r and a 2^n scale built
/// directly in the exponent bits.
///
/// Fixes relative to the previous revision of this kernel:
/// - the 1/5! coefficient literal had 17 hex digits (`0d3F8111…111` with one
///   extra `1`); PTX `0d` literals require exactly 16, so ptxas rejected it;
/// - the 1/3! (= 1/6) term was missing from the Horner ladder, which jumped
///   from 1/4! straight to 1/2! and skewed the r^3 coefficient.
///
/// NOTE(review): the 2^n scaling computes `(n + 1023) << 52` without
/// clamping, so arguments far enough into the underflow/overflow range
/// produce garbage instead of 0/inf — confirm the expected |x| range makes
/// that unreachable.
pub(crate) const GELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_f64_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f64 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%in];
mov.f64 %one, 0d3FF0000000000000;
// k ~= 1.70432 (same coefficient as the f32 kernel's 0f3FDA2720)
mov.f64 %k, 0d3FFB44E400000000;
mul.f64 %neg_kx, %k, %x;
neg.f64 %neg_kx, %neg_kx;
// --- exp(%neg_kx) via Cody-Waite + degree-12 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_kx, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_kx;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp_neg, %exp_neg, %e_nf;
// --- end exp ---
add.f64 %denom, %one, %exp_neg;
div.rn.f64 %sig, %one, %denom;
mul.f64 %result, %x, %sig;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `gelu_tanh_kernel`: tanh-approximated GELU,
/// `out = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`, with
/// tanh expressed as (e^{2y}-1)/(e^{2y}+1) using `ex2.approx` and
/// `rcp.approx`, so results carry approximate-instruction error.
pub(crate) const GELU_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_tanh_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f32 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
.reg .f32 %e2y_m1, %e2y_p1, %th, %one, %half, %log2e, %result;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%in];
// inner = sqrt(2/π) * (x + 0.044715 * x³)
// sqrt(2/π) = 0.7978845608 = 0x3F4C422A
// 0.044715 = 0x3D372713
mul.f32 %x3, %x, %x;
mul.f32 %x3, %x3, %x;
mov.f32 %c, 0f3D372713;
mul.f32 %x3, %c, %x3;
add.f32 %inner, %x, %x3;
mov.f32 %sqrt2pi, 0f3F4C422A;
mul.f32 %y, %sqrt2pi, %inner;
// tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)
// exp(2y) = 2^(2y * log2(e))
mov.f32 %log2e, 0f3FB8AA3B;
add.f32 %two_y, %y, %y;
mul.f32 %two_y, %two_y, %log2e;
ex2.approx.f32 %e2y, %two_y;
mov.f32 %one, 0f3F800000;
sub.f32 %e2y_m1, %e2y, %one;
add.f32 %e2y_p1, %e2y, %one;
rcp.approx.f32 %e2y_p1, %e2y_p1;
mul.f32 %th, %e2y_m1, %e2y_p1;
// out = 0.5 * x * (1 + tanh)
add.f32 %th, %one, %th;
mov.f32 %half, 0f3F000000;
mul.f32 %result, %half, %x;
mul.f32 %result, %result, %th;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for `gelu_tanh_f64_kernel`: double-precision tanh-approximated GELU,
/// `out = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`, with
/// exp(2y) computed by Cody–Waite range reduction plus a degree-12 Horner
/// polynomial (same scheme as `gelu_f64_kernel`).
///
/// Fixes relative to the previous revision (same copy-pasted defects as
/// `GELU_F64_PTX`): the 1/5! literal had 17 hex digits, which is invalid
/// PTX, and the 1/3! Horner term was missing.
///
/// NOTE(review): as in `gelu_f64_kernel`, the 2^n exponent-bit scaling is
/// unclamped and misbehaves when |2y| leaves the normal f64 exponent range.
pub(crate) const GELU_TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_tanh_f64_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f64 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
.reg .f64 %e2y_m1, %e2y_p1, %th, %one, %half, %result;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%in];
mov.f64 %one, 0d3FF0000000000000;
// inner = sqrt(2/pi) * (x + 0.044715 * x^3)
mul.f64 %x3, %x, %x;
mul.f64 %x3, %x3, %x;
mov.f64 %c, 0d3FA6E4E260000000;
mul.f64 %x3, %c, %x3;
add.f64 %inner, %x, %x3;
mov.f64 %sqrt2pi, 0d3FE9884540000000;
mul.f64 %y, %sqrt2pi, %inner;
// tanh(y) = (exp(2y)-1)/(exp(2y)+1), exp(2y) in full f64
add.f64 %two_y, %y, %y;
// --- exp(%two_y) via Cody-Waite + degree-12 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e2y, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e2y, %e2y, %e_nf;
// --- end exp ---
sub.f64 %e2y_m1, %e2y, %one;
add.f64 %e2y_p1, %e2y, %one;
div.rn.f64 %th, %e2y_m1, %e2y_p1;
// out = 0.5 * x * (1 + tanh)
add.f64 %th, %one, %th;
mov.f64 %half, 0d3FE0000000000000;
mul.f64 %result, %half, %x;
mul.f64 %result, %result, %th;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `gelu_erf_kernel`: elementwise f32 GELU in its exact
/// (erf-based) form, `out[i] = 0.5 * x * (1 + erf(x / sqrt(2)))`.
///
/// erf(|z|) is approximated as `1 - poly(t) * exp(-z^2)` with
/// `t = 1/(1 + p*|z|)` (rational-polynomial style), using the fast
/// approximate `ex2.approx` / `rcp.approx` instructions, then
/// sign-corrected via the odd symmetry `erf(-z) = -erf(z)`.
/// One thread handles one element; threads with `tid >= n` exit early,
/// so the launch grid must cover all `n` elements (no grid-stride loop).
///
/// NOTE(review): the polynomial coefficients differ from the classic
/// Abramowitz-Stegun 7.1.26 values — presumably a refit for f32;
/// worth confirming accuracy against a reference erf.
pub(crate) const GELU_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_erf_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f32 %x, %z, %ax, %one, %half, %log2e;
.reg .f32 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
.reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %result;
.reg .pred %pred_ge, %pred_neg;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %pred_ge, %r_tid, %n_reg;
@%pred_ge bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%in];
mov.f32 %one, 0f3F800000;
mov.f32 %half, 0f3F000000;
mov.f32 %log2e, 0f3FB8AA3B;
// z = x / sqrt(2) = x * 0.70710678
mov.f32 %z, 0f3F3504F3;
mul.f32 %z, %x, %z;
// |z| for erf(|z|)
abs.f32 %ax, %z;
// t = 1 / (1 + 0.3275911 * |z|)
mov.f32 %p, 0f3EA7BA05;
mul.f32 %t, %p, %ax;
add.f32 %t, %one, %t;
rcp.approx.f32 %t, %t;
// Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
mov.f32 %a5, 0f3E0AAAAB;
mov.f32 %a4, 0fBEB3A903;
mov.f32 %a3, 0f3FB506DD;
mov.f32 %a2, 0fBF03C1E1;
mov.f32 %a1, 0f3EA0D6BB;
mul.f32 %pt, %t, %a5;
add.f32 %pt, %pt, %a4;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a3;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a2;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a1;
mul.f32 %pt, %pt, %t;
// exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
mul.f32 %z2, %ax, %ax;
neg.f32 %neg_z2, %z2;
mul.f32 %neg_z2, %neg_z2, %log2e;
ex2.approx.f32 %exp_neg_z2, %neg_z2;
// erf(|z|) = 1 - poly * exp(-z^2)
mul.f32 %erf_val, %pt, %exp_neg_z2;
sub.f32 %erf_val, %one, %erf_val;
// erf(-z) = -erf(z), so sign-correct
setp.lt.f32 %pred_neg, %z, 0f00000000;
@%pred_neg neg.f32 %erf_val, %erf_val;
// out = x * 0.5 * (1 + erf(x/sqrt(2)))
add.f32 %erf_val, %one, %erf_val;
mul.f32 %result, %half, %x;
mul.f32 %result, %result, %erf_val;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `gelu_erf_f64_kernel`: elementwise f64 GELU in its
/// exact (erf-based) form, `out = 0.5 * x * (1 + erf(x / sqrt(2)))`.
/// erf is a rational polynomial in `t = 1/(1 + p*|z|)` times
/// `exp(-z^2)`; exp is computed in full f64 via Cody-Waite range
/// reduction, a Horner polynomial, and 2^n scaling assembled from
/// raw exponent bits. One thread per element; `tid >= n` threads exit.
///
/// Fixed: the 1/120 exp coefficient was written with 17 hex digits
/// (`0d3F811111111111111`); PTX f64 immediates require exactly 16, so
/// the module would fail to assemble. Now `0d3F81111111111111`.
///
/// NOTE(review): the erf constants look like f32-precision values
/// widened to f64, and the 2^n scaling does not clamp the biased
/// exponent — accuracy and extreme-input behavior are worth confirming.
pub(crate) const GELU_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_erf_f64_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f64 %x, %z, %ax, %one, %half;
.reg .f64 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
.reg .f64 %p, %a1, %a2, %a3, %a4, %a5, %result;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %pred_ge, %pred_neg;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %pred_ge, %r_tid, %n_reg;
@%pred_ge bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%in];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %half, 0d3FE0000000000000;
// z = x / sqrt(2) = x * 0.70710678
mov.f64 %z, 0d3FE6A09E60000000;
mul.f64 %z, %x, %z;
abs.f64 %ax, %z;
// t = 1 / (1 + 0.3275911 * |z|)
mov.f64 %p, 0d3FD4F740A0000000;
mul.f64 %t, %p, %ax;
add.f64 %t, %one, %t;
div.rn.f64 %t, %one, %t;
// Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
mov.f64 %a5, 0d3FC1555560000000;
mov.f64 %a4, 0dBFD6752060000000;
mov.f64 %a3, 0d3FF6A0DBA0000000;
mov.f64 %a2, 0dBFE0783C20000000;
mov.f64 %a1, 0d3FD41AD760000000;
mul.f64 %pt, %t, %a5;
add.f64 %pt, %pt, %a4;
mul.f64 %pt, %pt, %t;
add.f64 %pt, %pt, %a3;
mul.f64 %pt, %pt, %t;
add.f64 %pt, %pt, %a2;
mul.f64 %pt, %pt, %t;
add.f64 %pt, %pt, %a1;
mul.f64 %pt, %pt, %t;
// exp(-z^2) in full f64
mul.f64 %z2, %ax, %ax;
neg.f64 %neg_z2, %z2;
// --- exp(%neg_z2) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_z2, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_z2;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp_neg_z2, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp_neg_z2, %exp_neg_z2, %e_nf;
// --- end exp ---
mul.f64 %erf_val, %pt, %exp_neg_z2;
sub.f64 %erf_val, %one, %erf_val;
setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
@%pred_neg neg.f64 %erf_val, %erf_val;
add.f64 %erf_val, %one, %erf_val;
mul.f64 %result, %half, %x;
mul.f64 %result, %result, %erf_val;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `gelu_backward_tanh_kernel`: elementwise f32 backward
/// pass of the tanh-approximated GELU. For `u = sqrt(2/pi) * (x +
/// 0.044715*x^3)`, computes
/// `grad * (0.5*(1+tanh(u)) + 0.5*x*(1-tanh(u)^2)*du/dx)`.
/// Uses fast approximate `ex2.approx` / `rcp.approx`. One thread per
/// element; `tid >= n` threads exit early.
///
/// Fixed: tanh was computed as `(e-1) * rcp(e+1)` with `e = exp(2u)`;
/// once `2u*log2(e) > 128` (x >~ 10.7) `ex2.approx` overflows to +inf
/// and that form evaluates `inf * 0 = NaN`. It is now computed as
/// `tanh = 1 - 2/(e+1)`, which is algebraically identical and gives the
/// correct limit 1 at overflow (rcp.approx(+inf) = +0).
pub(crate) const GELU_BACKWARD_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_tanh_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
.reg .f32 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half, %log2e;
.reg .f32 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
mov.f32 %one, 0f3F800000;
mov.f32 %half, 0f3F000000;
mov.f32 %log2e, 0f3FB8AA3B;
mov.f32 %sqrt2pi, 0f3F4C422A;
mov.f32 %c, 0f3D372713;
// 3 * 0.044715 = 0.134145 = 0x3E096B8C
mov.f32 %c3, 0f3E096B8C;
// u = sqrt(2/π) * (x + 0.044715 * x³)
mul.f32 %x2, %x, %x;
mul.f32 %x3, %x2, %x;
mul.f32 %x3, %c, %x3;
add.f32 %inner, %x, %x3;
mul.f32 %y, %sqrt2pi, %inner;
// tanh(y) = 1 - 2/(exp(2y)+1): stays finite when ex2 overflows to inf
add.f32 %two_y, %y, %y;
mul.f32 %two_y, %two_y, %log2e;
ex2.approx.f32 %e2y, %two_y;
add.f32 %e2y_p1, %e2y, %one;
rcp.approx.f32 %e2y_p1, %e2y_p1;
add.f32 %e2y_m1, %e2y_p1, %e2y_p1;
sub.f32 %th, %one, %e2y_m1;
// d/dx = 0.5*(1+tanh) + 0.5*x*(1-tanh²)*sqrt(2/π)*(1+3*0.044715*x²)
// term1 = 0.5 * (1 + th)
add.f32 %term1, %one, %th;
mul.f32 %term1, %half, %term1;
// (1 - th²)
mul.f32 %th2, %th, %th;
sub.f32 %one_m_th2, %one, %th2;
// d_inner = sqrt(2/π) * (1 + 3*0.044715*x²)
mul.f32 %d_inner, %c3, %x2;
add.f32 %d_inner, %one, %d_inner;
mul.f32 %d_inner, %sqrt2pi, %d_inner;
// term2 = 0.5 * x * (1-th²) * d_inner
mul.f32 %term2, %half, %x;
mul.f32 %term2, %term2, %one_m_th2;
mul.f32 %term2, %term2, %d_inner;
add.f32 %d_gelu, %term1, %term2;
mul.f32 %result, %vg, %d_gelu;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `gelu_backward_tanh_f64_kernel`: elementwise f64
/// backward pass of the tanh-approximated GELU (same formula as the
/// f32 kernel), with exp evaluated in full f64 via Cody-Waite range
/// reduction + Horner polynomial + 2^n exponent-bit scaling.
/// One thread per element; `tid >= n` threads exit early.
///
/// Fixed:
/// 1. The 1/120 exp coefficient had 17 hex digits
///    (`0d3F811111111111111`); PTX f64 immediates require exactly 16,
///    so the module would fail to assemble. Now `0d3F81111111111111`.
/// 2. tanh was `(e-1)/(e+1)`, which becomes inf/inf = NaN once the
///    2^n scaling saturates (x >~ 21); it is now the algebraically
///    identical `1 - 2/(e+1)`, which tends to the correct limit 1.
pub(crate) const GELU_BACKWARD_TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_tanh_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
.reg .f64 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half;
.reg .f64 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %half, 0d3FE0000000000000;
mov.f64 %sqrt2pi, 0d3FE9884540000000;
mov.f64 %c, 0d3FA6E4E260000000;
// 3 * 0.044715 = 0.134145
mov.f64 %c3, 0d3FC12D7180000000;
mul.f64 %x2, %x, %x;
mul.f64 %x3, %x2, %x;
mul.f64 %x3, %c, %x3;
add.f64 %inner, %x, %x3;
mul.f64 %y, %sqrt2pi, %inner;
// tanh(y) = 1 - 2/(exp(2y)+1) in full f64
add.f64 %two_y, %y, %y;
// --- exp(%two_y) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e2y, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e2y, %e2y, %e_nf;
// --- end exp ---
// tanh via 1 - 2/(exp(2y)+1): finite even if exp(2y) overflows to inf
add.f64 %e2y_p1, %e2y, %one;
div.rn.f64 %e2y_m1, %one, %e2y_p1;
add.f64 %e2y_m1, %e2y_m1, %e2y_m1;
sub.f64 %th, %one, %e2y_m1;
add.f64 %term1, %one, %th;
mul.f64 %term1, %half, %term1;
mul.f64 %th2, %th, %th;
sub.f64 %one_m_th2, %one, %th2;
mul.f64 %d_inner, %c3, %x2;
add.f64 %d_inner, %one, %d_inner;
mul.f64 %d_inner, %sqrt2pi, %d_inner;
mul.f64 %term2, %half, %x;
mul.f64 %term2, %term2, %one_m_th2;
mul.f64 %term2, %term2, %d_inner;
add.f64 %d_gelu, %term1, %term2;
mul.f64 %result, %vg, %d_gelu;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `silu_kernel`: elementwise f32 SiLU (swish),
/// `out[i] = x * sigmoid(x)` with `sigmoid(x) = 1/(1 + exp(-x))` built
/// from the fast approximate `ex2.approx` / `rcp.approx` instructions.
/// One thread handles one element; threads with `tid >= n` exit early,
/// so the launch grid must cover all `n` elements (no grid-stride loop).
/// Large negative x is benign: exp(-x) overflows to +inf, rcp.approx
/// maps +inf to +0, and the output saturates cleanly to 0.
pub(crate) const SILU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry silu_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %x, %neg, %e, %denom, %sig, %vr, %one, %lg2e;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%a];
// sigmoid(x) = 1 / (1 + exp(-x))
// exp(-x) = 2^(-x * log2(e))
mov.f32 %one, 0f3F800000;
mov.f32 %lg2e, 0f3FB8AA3B;
neg.f32 %neg, %x;
mul.f32 %neg, %neg, %lg2e;
ex2.approx.f32 %e, %neg;
add.f32 %denom, %one, %e;
rcp.approx.f32 %sig, %denom;
// silu(x) = x * sigmoid(x)
mul.f32 %vr, %x, %sig;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `silu_f64_kernel`: elementwise f64 SiLU,
/// `out = x * sigmoid(x)`, with exp(-x) computed in full f64 via
/// Cody-Waite range reduction + degree-11 Horner polynomial + 2^n
/// scaling assembled from exponent bits. One thread per element;
/// `tid >= n` threads exit early.
///
/// Fixed: the 1/120 exp coefficient had 17 hex digits
/// (`0d3F811111111111111`); PTX f64 immediates require exactly 16, so
/// the module would fail to assemble. Now `0d3F81111111111111`.
///
/// NOTE(review): the 2^n scaling does not clamp `n + 1023`, so inputs
/// below ~-709 (where exp(-x) overflows f64) mis-scale — confirm the
/// expected input domain.
pub(crate) const SILU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry silu_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %x, %neg_x, %e, %denom, %sig, %vr, %one;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%a];
mov.f64 %one, 0d3FF0000000000000;
neg.f64 %neg_x, %x;
// --- exp(%neg_x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e, %e, %e_nf;
// --- end exp ---
add.f64 %denom, %one, %e;
div.rn.f64 %sig, %one, %denom;
mul.f64 %vr, %x, %sig;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `silu_backward_kernel`: elementwise f32 backward pass
/// of SiLU. With `sig = sigmoid(x)`, computes
/// `out = grad * (sig + x * sig * (1 - sig))`, which is d(x*sig)/dx.
/// sigmoid uses the fast approximate `ex2.approx` / `rcp.approx`
/// instructions. One thread per element; `tid >= n` threads exit early,
/// so the launch grid must cover all `n` elements.
pub(crate) const SILU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry silu_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %neg, %e, %denom, %sig, %one, %lg2e;
.reg .f32 %one_m_sig, %x_sig_omsig, %deriv, %result;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
// sig = sigmoid(x) = 1 / (1 + exp(-x))
mov.f32 %one, 0f3F800000;
mov.f32 %lg2e, 0f3FB8AA3B;
neg.f32 %neg, %x;
mul.f32 %neg, %neg, %lg2e;
ex2.approx.f32 %e, %neg;
add.f32 %denom, %one, %e;
rcp.approx.f32 %sig, %denom;
// deriv = sig + x * sig * (1 - sig)
sub.f32 %one_m_sig, %one, %sig;
mul.f32 %x_sig_omsig, %x, %sig;
mul.f32 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
add.f32 %deriv, %sig, %x_sig_omsig;
mul.f32 %result, %vg, %deriv;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `silu_backward_f64_kernel`: elementwise f64 backward
/// pass of SiLU, `out = grad * (sig + x * sig * (1 - sig))` with
/// `sig = sigmoid(x)`. exp(-x) is computed in full f64 via Cody-Waite
/// range reduction + Horner polynomial + 2^n exponent-bit scaling.
/// One thread per element; `tid >= n` threads exit early.
///
/// Fixed: the 1/120 exp coefficient had 17 hex digits
/// (`0d3F811111111111111`); PTX f64 immediates require exactly 16, so
/// the module would fail to assemble. Now `0d3F81111111111111`.
pub(crate) const SILU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry silu_backward_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %neg_x, %e, %denom, %sig, %one;
.reg .f64 %one_m_sig, %x_sig_omsig, %deriv, %result;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %one, 0d3FF0000000000000;
neg.f64 %neg_x, %x;
// --- exp(%neg_x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e, %e, %e_nf;
// --- end exp ---
add.f64 %denom, %one, %e;
div.rn.f64 %sig, %one, %denom;
sub.f64 %one_m_sig, %one, %sig;
mul.f64 %x_sig_omsig, %x, %sig;
mul.f64 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
add.f64 %deriv, %sig, %x_sig_omsig;
mul.f64 %result, %vg, %deriv;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `elu_kernel`: elementwise f32 ELU,
/// `out = x > 0 ? x : alpha * (exp(x) - 1)`, with exp built from the
/// fast approximate `ex2.approx`. `alpha` is passed as an f32 kernel
/// parameter. One thread per element; `tid >= n` threads exit early,
/// so the launch grid must cover all `n` elements.
/// The negative branch is computed unconditionally and discarded by
/// `selp` when x > 0; this is safe because PTX float ops do not trap
/// (an overflowed exp(x) for large positive x is simply never selected).
pub(crate) const ELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry elu_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .f32 alpha
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %x, %alpha_r, %lg2e, %one, %ex, %em1, %neg_branch, %vr;
.reg .pred %p, %pos;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
ld.param.f32 %alpha_r, [alpha];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%a];
mov.f32 %one, 0f3F800000;
mov.f32 %lg2e, 0f3FB8AA3B;
// exp(x) = 2^(x * log2(e))
mul.f32 %ex, %x, %lg2e;
ex2.approx.f32 %ex, %ex;
sub.f32 %em1, %ex, %one;
mul.f32 %neg_branch, %alpha_r, %em1;
// x > 0 ? x : alpha*(exp(x)-1)
mov.f32 %vr, 0f00000000;
setp.gt.f32 %pos, %x, %vr;
selp.f32 %vr, %x, %neg_branch, %pos;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `elu_f64_kernel`: elementwise f64 ELU,
/// `out = x > 0 ? x : alpha * (exp(x) - 1)`, with exp computed in full
/// f64 via Cody-Waite range reduction + Horner polynomial + 2^n
/// exponent-bit scaling; the negative branch is computed
/// unconditionally and discarded by `selp` when x > 0.
/// One thread per element; `tid >= n` threads exit early.
///
/// Fixed: the 1/120 exp coefficient had 17 hex digits
/// (`0d3F811111111111111`); PTX f64 immediates require exactly 16, so
/// the module would fail to assemble. Now `0d3F81111111111111`.
///
/// NOTE(review): the 2^n scaling does not clamp `n + 1023`, so for
/// x below ~-709 the (selected) negative branch can mis-scale —
/// confirm the expected input domain.
pub(crate) const ELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry elu_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .f64 alpha
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %x, %alpha_r, %one, %ex, %em1, %neg_branch, %vr;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p, %pos;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
ld.param.f64 %alpha_r, [alpha];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%a];
mov.f64 %one, 0d3FF0000000000000;
// --- exp(%x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %ex, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ex, %ex, %e_nf;
// --- end exp ---
sub.f64 %em1, %ex, %one;
mul.f64 %neg_branch, %alpha_r, %em1;
mov.f64 %vr, 0d0000000000000000;
setp.gt.f64 %pos, %x, %vr;
selp.f64 %vr, %x, %neg_branch, %pos;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `elu_backward_kernel`: elementwise f32 backward pass
/// of ELU, `out = x > 0 ? grad : grad * alpha * exp(x)` (the derivative
/// of the negative branch is alpha*exp(x)). exp uses `ex2.approx`.
/// One thread per element; `tid >= n` threads exit early, so the
/// launch grid must cover all `n` elements.
/// The negative branch is computed unconditionally and discarded by
/// `selp` when x > 0, which is safe since PTX float ops do not trap.
pub(crate) const ELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry elu_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .f32 alpha
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %alpha_r, %lg2e, %ex, %neg_branch, %vr, %zero;
.reg .pred %p, %pos;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
ld.param.f32 %alpha_r, [alpha];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
mov.f32 %lg2e, 0f3FB8AA3B;
mov.f32 %zero, 0f00000000;
// exp(x) = 2^(x * log2(e))
mul.f32 %ex, %x, %lg2e;
ex2.approx.f32 %ex, %ex;
// negative branch: grad * alpha * exp(x)
mul.f32 %neg_branch, %vg, %alpha_r;
mul.f32 %neg_branch, %neg_branch, %ex;
// x > 0 ? grad : grad * alpha * exp(x)
setp.gt.f32 %pos, %x, %zero;
selp.f32 %vr, %vg, %neg_branch, %pos;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `elu_backward_f64_kernel`: elementwise f64 backward
/// pass of ELU, `out = x > 0 ? grad : grad * alpha * exp(x)`, with exp
/// computed in full f64 via Cody-Waite range reduction + Horner
/// polynomial + 2^n exponent-bit scaling. One thread per element;
/// `tid >= n` threads exit early.
///
/// Fixed: the 1/120 exp coefficient had 17 hex digits
/// (`0d3F811111111111111`); PTX f64 immediates require exactly 16, so
/// the module would fail to assemble. Now `0d3F81111111111111`.
pub(crate) const ELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry elu_backward_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .f64 alpha
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %alpha_r, %ex, %neg_branch, %vr, %zero, %one;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p, %pos;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
ld.param.f64 %alpha_r, [alpha];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %zero, 0d0000000000000000;
mov.f64 %one, 0d3FF0000000000000;
// --- exp(%x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %ex, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ex, %ex, %e_nf;
// --- end exp ---
mul.f64 %neg_branch, %vg, %alpha_r;
mul.f64 %neg_branch, %neg_branch, %ex;
setp.gt.f64 %pos, %x, %zero;
selp.f64 %vr, %vg, %neg_branch, %pos;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `mish_kernel`: elementwise f32 Mish,
/// `out = x * tanh(softplus(x))` with `softplus(x) = ln(1 + exp(x))`,
/// built from fast approximate `ex2.approx` / `lg2.approx` /
/// `rcp.approx`. For x > 20 softplus overflows the intermediate
/// exp(x), so a separate branch is taken. One thread per element;
/// `tid >= n` threads exit early.
///
/// Fixed: the large-x branch previously evaluated tanh(x) via exp(2x);
/// for x > ~44, `2x*log2(e)` exceeds 128 so `ex2.approx` overflows to
/// +inf and `(e-1) * rcp(e+1)` produced `inf * 0 = NaN`. Since
/// tanh(x) rounds to 1 in f32 for every x > 20, the branch now stores
/// x directly — identical results for 20 < x <= 44, NaN eliminated
/// beyond.
pub(crate) const MISH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mish_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
.reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
.reg .f32 %threshold;
.reg .pred %p, %large;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%a];
mov.f32 %one, 0f3F800000;
mov.f32 %lg2e, 0f3FB8AA3B;
// threshold = 20.0 = 0x41A00000
mov.f32 %threshold, 0f41A00000;
// softplus(x) = ln(1 + exp(x))
// For large x (> 20), softplus ~ x to avoid overflow
setp.gt.f32 %large, %x, %threshold;
@%large bra LARGE_X;
// exp(x) = 2^(x * log2(e))
mul.f32 %ex, %x, %lg2e;
ex2.approx.f32 %ex, %ex;
add.f32 %ep1, %ex, %one;
// ln(1+exp(x)) = log2(1+exp(x)) / log2(e)
lg2.approx.f32 %lg_ep1, %ep1;
// 1/log2(e) = ln(2) = 0.6931472 = 0x3F317218
mul.f32 %sp, %lg_ep1, 0f3F317218;
// tanh(sp) = (exp(2*sp) - 1) / (exp(2*sp) + 1)
add.f32 %two_sp, %sp, %sp;
mul.f32 %two_sp, %two_sp, %lg2e;
ex2.approx.f32 %e2sp, %two_sp;
sub.f32 %e2sp_m1, %e2sp, %one;
add.f32 %e2sp_p1, %e2sp, %one;
rcp.approx.f32 %e2sp_p1, %e2sp_p1;
mul.f32 %th, %e2sp_m1, %e2sp_p1;
mul.f32 %vr, %x, %th;
st.global.f32 [%out], %vr;
bra DONE;
LARGE_X:
// For x > 20, softplus(x) ~ x and tanh(x) rounds to 1 in f32, so
// mish(x) ~ x; computing tanh via exp(2x) here overflowed ex2 for
// x > ~44 and produced inf * 0 = NaN. Store x directly instead.
st.global.f32 [%out], %x;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `mish_f64_kernel`: elementwise f64 Mish,
/// `out = x * tanh(softplus(x))`. softplus uses a full f64 exp
/// (Cody-Waite + Horner + 2^n exponent-bit scaling) and a full f64
/// ln(1+exp(x)) via mantissa/exponent decomposition and an atanh-style
/// polynomial. One thread per element; `tid >= n` threads exit early.
///
/// Fixed:
/// 1. The 1/120 exp coefficient had 17 hex digits
///    (`0d3F811111111111111`); PTX f64 immediates require exactly 16,
///    so the module would fail to assemble. Now `0d3F81111111111111`.
/// 2. The large-x (> 20) branch evaluated tanh(x) via exp(2x), whose
///    unclamped 2^n scaling mis-scales for x > ~354. tanh(x) rounds to
///    1.0 in f64 for every x > 20, so the branch now stores x directly.
///
/// NOTE(review): `%two` is declared and loaded but never used —
/// presumably left over from an earlier formulation.
pub(crate) const MISH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mish_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %x, %one, %two, %ex, %ep1, %sp;
.reg .f64 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
.reg .f64 %threshold;
// exp subroutine regs
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
// log subroutine regs
.reg .u64 %l_xbits, %l_mbits, %l_bias;
.reg .s64 %l_exp64;
.reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;
.reg .pred %p, %large;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%a];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %two, 0d4000000000000000;
mov.f64 %threshold, 0d4034000000000000;
setp.gt.f64 %large, %x, %threshold;
@%large bra LARGE_X;
// === softplus: sp = ln(1 + exp(x)) ===
// exp(x)
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %ex, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ex, %ex, %e_nf;
// ep1 = 1 + exp(x)
add.f64 %ep1, %ex, %one;
// ln(ep1) via argument reduction
mov.b64 %l_xbits, %ep1;
shr.u64 %l_exp64, %l_xbits, 52;
and.b64 %l_exp64, %l_exp64, 2047;
sub.s64 %l_exp64, %l_exp64, 1023;
cvt.rn.f64.s64 %l_nf, %l_exp64;
mov.u64 %l_bias, 0x3FF0000000000000;
and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
or.b64 %l_mbits, %l_mbits, %l_bias;
mov.b64 %l_m, %l_mbits;
sub.f64 %l_f, %l_m, %one;
add.f64 %l_s, %l_m, %one;
div.rn.f64 %l_f, %l_f, %l_s;
mul.f64 %l_f2, %l_f, %l_f;
mov.f64 %l_p, 0d3FB745D1745D1746;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
fma.rn.f64 %l_p, %l_p, %l_f2, %one;
mul.f64 %l_p, %l_p, %l_f;
add.f64 %l_p, %l_p, %l_p;
mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
fma.rn.f64 %sp, %l_nf, %l_ln2, %l_p;
// === tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1) ===
add.f64 %two_sp, %sp, %sp;
fma.rn.f64 %e_nf, %two_sp, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e2sp, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e2sp, %e2sp, %e_nf;
sub.f64 %e2sp_m1, %e2sp, %one;
add.f64 %e2sp_p1, %e2sp, %one;
div.rn.f64 %th, %e2sp_m1, %e2sp_p1;
mul.f64 %vr, %x, %th;
st.global.f64 [%out], %vr;
bra DONE;
LARGE_X:
// For x > 20, softplus(x) ~ x and tanh(x) rounds to 1.0 in f64, so
// mish(x) ~ x; the old exp(2x) evaluation mis-scaled for x > ~354
// (unclamped exponent-bit assembly). Store x directly instead.
st.global.f64 [%out], %x;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for `mish_backward_kernel`: elementwise f32 backward pass
/// of Mish. With `t = tanh(softplus(x))` and `sig = sigmoid(x)`,
/// computes `out = grad * (t + x * sig * (1 - t*t))`.
/// Uses fast approximate `ex2.approx` / `lg2.approx` / `rcp.approx`.
/// One thread per element; `tid >= n` threads exit early.
///
/// Fixed: the large-x (> 20) branch previously evaluated tanh(x) via
/// exp(2x); for x > ~44 `ex2.approx` overflows to +inf and
/// `(e-1) * rcp(e+1)` produced `inf * 0 = NaN`. For every x > 20,
/// tanh(softplus(x)) and sigmoid(x) both round to 1 in f32, making the
/// derivative exactly 1, so the branch now passes the gradient through
/// unchanged — identical results for 20 < x <= 44, NaN eliminated
/// beyond.
pub(crate) const MISH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mish_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
.reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
.reg .f32 %neg, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
.reg .f32 %threshold;
.reg .pred %p, %large;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
mov.f32 %one, 0f3F800000;
mov.f32 %lg2e, 0f3FB8AA3B;
// threshold = 20.0
mov.f32 %threshold, 0f41A00000;
setp.gt.f32 %large, %x, %threshold;
@%large bra LARGE_X;
// --- Normal path ---
// softplus: sp = ln(1 + exp(x))
mul.f32 %ex, %x, %lg2e;
ex2.approx.f32 %ex, %ex;
add.f32 %ep1, %ex, %one;
lg2.approx.f32 %lg_ep1, %ep1;
// ln(2) = 0x3F317218
mul.f32 %sp, %lg_ep1, 0f3F317218;
// t = tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1)
add.f32 %two_sp, %sp, %sp;
mul.f32 %two_sp, %two_sp, %lg2e;
ex2.approx.f32 %e2sp, %two_sp;
sub.f32 %e2sp_m1, %e2sp, %one;
add.f32 %e2sp_p1, %e2sp, %one;
rcp.approx.f32 %e2sp_p1, %e2sp_p1;
mul.f32 %t, %e2sp_m1, %e2sp_p1;
// sig = sigmoid(x) = 1/(1+exp(-x))
neg.f32 %neg, %x;
mul.f32 %neg, %neg, %lg2e;
ex2.approx.f32 %en, %neg;
add.f32 %denom, %one, %en;
rcp.approx.f32 %sig, %denom;
// deriv = t + x * sig * (1 - t*t)
mul.f32 %t2, %t, %t;
sub.f32 %one_m_t2, %one, %t2;
mul.f32 %x_sig_omt2, %x, %sig;
mul.f32 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
add.f32 %deriv, %t, %x_sig_omt2;
mul.f32 %result, %vg, %deriv;
st.global.f32 [%out], %result;
bra DONE;
LARGE_X:
// For x > 20: tanh(softplus(x)) and sigmoid(x) round to 1 in f32, so
// d(mish)/dx ~ 1. The old exp(2x) path overflowed ex2 for x > ~44
// (inf * 0 = NaN); pass the gradient through unchanged instead.
st.global.f32 [%out], %vg;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the f64 Mish backward kernel (entry `mish_backward_f64_kernel`).
///
/// Computes `out[i] = grad[i] * d/dx mish(x)` where
/// `d/dx mish(x) = t + x * sigmoid(x) * (1 - t^2)` and `t = tanh(softplus(x))`.
/// exp() is evaluated inline via Cody-Waite argument reduction plus a Horner
/// polynomial (coefficients 1/12! .. 1/3!, then 1/2, 1); ln() via
/// mantissa/exponent extraction and an atanh-series polynomial. For
/// `x > 20.0` a simplified path is taken (`softplus(x) ~ x`, `sigmoid(x) ~ 1`).
///
/// Fixes applied:
/// - the 1/5! coefficient was written `0d3F811111111111111` (17 hex digits);
///   PTX `0d` literals require exactly 16, so ptxas rejects the module.
///   Corrected to `0d3F81111111111111`.
/// - each exp() Horner ladder skipped the 1/4! rung (`0d3FA5555555555555`),
///   jumping from 1/120 to 1/6 and mis-weighting the r^4 term; the rung is
///   restored in all four expansions.
pub(crate) const MISH_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mish_backward_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %one, %ex, %ep1, %sp;
.reg .f64 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
.reg .f64 %neg_x, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
.reg .f64 %threshold;
// exp subroutine regs
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
// log subroutine regs
.reg .u64 %l_xbits, %l_mbits, %l_bias;
.reg .s64 %l_exp64;
.reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;
.reg .pred %p, %large;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %threshold, 0d4034000000000000;
setp.gt.f64 %large, %x, %threshold;
@%large bra LARGE_X;
// === softplus: sp = ln(1 + exp(x)) ===
// exp(x)
mov.f64 %e_half, 0d3FE0000000000000;
mul.f64 %e_nf, %x, 0d3FF71547652B82FE;
cvt.rni.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %ex, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ex, %ex, %e_nf;
add.f64 %ep1, %ex, %one;
// ln(ep1) via argument reduction
mov.b64 %l_xbits, %ep1;
shr.u64 %l_exp64, %l_xbits, 52;
and.b64 %l_exp64, %l_exp64, 2047;
sub.s64 %l_exp64, %l_exp64, 1023;
cvt.rn.f64.s64 %l_nf, %l_exp64;
mov.u64 %l_bias, 0x3FF0000000000000;
and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
or.b64 %l_mbits, %l_mbits, %l_bias;
mov.b64 %l_m, %l_mbits;
sub.f64 %l_f, %l_m, %one;
add.f64 %l_s, %l_m, %one;
div.rn.f64 %l_f, %l_f, %l_s;
mul.f64 %l_f2, %l_f, %l_f;
mov.f64 %l_p, 0d3FB745D1745D1746;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
fma.rn.f64 %l_p, %l_p, %l_f2, %one;
mul.f64 %l_p, %l_p, %l_f;
add.f64 %l_p, %l_p, %l_p;
mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
fma.rn.f64 %sp, %l_nf, %l_ln2, %l_p;
// === tanh(sp) ===
add.f64 %two_sp, %sp, %sp;
mul.f64 %e_nf, %two_sp, 0d3FF71547652B82FE;
cvt.rni.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e2sp, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e2sp, %e2sp, %e_nf;
sub.f64 %e2sp_m1, %e2sp, %one;
add.f64 %e2sp_p1, %e2sp, %one;
div.rn.f64 %t, %e2sp_m1, %e2sp_p1;
// === sigmoid(x) = 1/(1+exp(-x)) ===
neg.f64 %neg_x, %x;
mul.f64 %e_nf, %neg_x, 0d3FF71547652B82FE;
cvt.rni.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %en, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %en, %en, %e_nf;
add.f64 %denom, %one, %en;
div.rn.f64 %sig, %one, %denom;
// deriv = t + x * sig * (1 - t*t)
mul.f64 %t2, %t, %t;
sub.f64 %one_m_t2, %one, %t2;
mul.f64 %x_sig_omt2, %x, %sig;
mul.f64 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
add.f64 %deriv, %t, %x_sig_omt2;
mul.f64 %result, %vg, %deriv;
st.global.f64 [%out], %result;
bra DONE;
LARGE_X:
// sp ~ x, tanh(x) in f64, sig ~ 1
add.f64 %two_sp, %x, %x;
mov.f64 %e_half, 0d3FE0000000000000;
mul.f64 %e_nf, %two_sp, 0d3FF71547652B82FE;
cvt.rni.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e2sp, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e2sp, %e2sp, %e_nf;
sub.f64 %e2sp_m1, %e2sp, %one;
add.f64 %e2sp_p1, %e2sp, %one;
div.rn.f64 %t, %e2sp_m1, %e2sp_p1;
// sig ~ 1, deriv ~ t + x*(1-t*t)
mul.f64 %t2, %t, %t;
sub.f64 %one_m_t2, %one, %t2;
mul.f64 %x_sig_omt2, %x, %one_m_t2;
add.f64 %deriv, %t, %x_sig_omt2;
mul.f64 %result, %vg, %deriv;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the f32 clamp kernel (entry `clamp_kernel`).
///
/// Computes `out[i] = min(max(in[i], min_val), max_val)` with one thread per
/// element; threads with `tid >= n` exit early. `min_val`/`max_val` are passed
/// as f32 kernel parameters.
pub(crate) const CLAMP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry clamp_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .f32 min_val,
.param .f32 max_val
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f32 %x, %mn, %mx, %result;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
ld.param.f32 %mn, [min_val];
ld.param.f32 %mx, [max_val];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%in];
max.f32 %result, %x, %mn;
min.f32 %result, %result, %mx;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the f32 ReLU backward kernel (entry `relu_backward_kernel`).
///
/// Computes `out[i] = grad[i] if input[i] > 0 else 0` (gradient passes through
/// only where the forward input was strictly positive). One thread per
/// element; threads with `tid >= n` exit early.
pub(crate) const RELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry relu_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %vi, %zero, %vr;
.reg .pred %p, %pos;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %vi, [%input];
mov.f32 %zero, 0f00000000;
setp.gt.f32 %pos, %vi, %zero;
selp.f32 %vr, %vg, %zero, %pos;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the f32 GELU backward kernel (entry `gelu_backward_kernel`),
/// sigmoid-approximation variant.
///
/// Computes `out[i] = grad[i] * (sig + k*x*sig*(1-sig))` with
/// `sig = sigmoid(k*x)`, using the hardware `ex2.approx`/`rcp.approx`
/// instructions. One thread per element; threads with `tid >= n` exit early.
///
/// NOTE(review): the slope constant `0f3FDA2720` decodes to ~1.70432, while the
/// in-kernel comment (and the usual GELU sigmoid approximation) says 1.702
/// (`0f3F9DB23` pattern, i.e. 0x3FD9DB23) — confirm which value is intended.
pub(crate) const GELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
.reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
// sig = sigmoid(1.702 * x)
mov.f32 %k, 0f3FDA2720;
mul.f32 %kx, %k, %x;
neg.f32 %neg_kx, %kx;
mov.f32 %log2e, 0f3FB8AA3B;
mul.f32 %neg_kx, %neg_kx, %log2e;
ex2.approx.f32 %exp_neg, %neg_kx;
mov.f32 %one, 0f3F800000;
add.f32 %denom, %one, %exp_neg;
rcp.approx.f32 %sig, %denom;
// d/dx gelu(x) = sig + k * x * sig * (1 - sig)
sub.f32 %one_minus_sig, %one, %sig;
mul.f32 %kx_sig_oms, %kx, %sig;
mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
add.f32 %dsig, %sig, %kx_sig_oms;
// out = grad * d_gelu
mul.f32 %result, %vg, %dsig;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the f64 GELU backward kernel (entry
/// `gelu_backward_f64_kernel`), sigmoid-approximation variant.
///
/// Computes `out[i] = grad[i] * (sig + k*x*sig*(1-sig))` with
/// `sig = sigmoid(k*x)`; exp() is evaluated inline via Cody-Waite argument
/// reduction plus a Horner polynomial (coefficients 1/12! .. 1/3!, then
/// 1/2, 1). One thread per element; threads with `tid >= n` exit early.
///
/// Fixes applied:
/// - the 1/5! coefficient was written `0d3F811111111111111` (17 hex digits);
///   PTX `0d` literals require exactly 16, so ptxas rejects the module.
///   Corrected to `0d3F81111111111111`.
/// - the exp() Horner ladder skipped the 1/3! rung (`0d3FC5555555555555`),
///   jumping from 1/24 to 1/2 and mis-weighting the r^3 term (~0.5% relative
///   error at the edge of the reduced range); the rung is restored.
///
/// NOTE(review): `%k` = `0d3FFB44E400000000` ~ 1.70432 (matches the f32
/// kernel's `0f3FDA2720`); the usual GELU sigmoid approximation uses 1.702 —
/// confirm which value is intended.
pub(crate) const GELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %k, %kx, %neg_kx, %exp_neg, %one, %denom, %sig;
.reg .f64 %one_minus_sig, %kx_sig_oms, %dsig, %result;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %k, 0d3FFB44E400000000;
mul.f64 %kx, %k, %x;
neg.f64 %neg_kx, %kx;
// --- exp(%neg_kx) via Cody-Waite + degree-12 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_kx, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_kx;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp_neg, %exp_neg, %e_nf;
// --- end exp ---
add.f64 %denom, %one, %exp_neg;
div.rn.f64 %sig, %one, %denom;
sub.f64 %one_minus_sig, %one, %sig;
mul.f64 %kx_sig_oms, %kx, %sig;
mul.f64 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
add.f64 %dsig, %sig, %kx_sig_oms;
mul.f64 %result, %vg, %dsig;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the f32 erf-based GELU backward kernel (entry
/// `gelu_backward_erf_kernel`).
///
/// Computes `out[i] = grad[i] * (Phi(x) + x*phi(x))` where
/// `Phi(x) = 0.5*(1 + erf(x/sqrt(2)))` and `phi(x) = exp(-x^2/2)/sqrt(2*pi)`.
/// erf is evaluated with an Abramowitz-Stegun-style rational approximation
/// `erf(|z|) ~ 1 - poly(t)*exp(-z^2)`, `t = 1/(1 + p*|z|)`, with odd symmetry
/// applied afterwards; exp uses the hardware `ex2.approx`.
///
/// NOTE(review): `%p` = 0.3275911 matches A&S 7.1.26, but the a1..a5
/// coefficients do not (e.g. `%a3` = `0f3FB506DD` is exactly sqrt(2), vs A&S
/// 1.421413741) — presumably a custom refit; verify against the forward
/// kernel's erf coefficients.
pub(crate) const GELU_BACKWARD_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_erf_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
.reg .f32 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
.reg .f32 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
.reg .f32 %d_gelu, %result;
.reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %log2e;
.reg .pred %pred_ge, %pred_neg;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %pred_ge, %r_tid, %n_reg;
@%pred_ge bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
mov.f32 %one, 0f3F800000;
mov.f32 %half, 0f3F000000;
// z = x / sqrt(2) = x * 0.70710678
mov.f32 %z, 0f3F3504F3;
mul.f32 %z, %x, %z;
// |z| for erf(|z|)
abs.f32 %ax, %z;
// t = 1 / (1 + 0.3275911 * |z|)
mov.f32 %p, 0f3EA7BA05;
mul.f32 %t, %p, %ax;
add.f32 %t, %one, %t;
rcp.approx.f32 %t, %t;
// Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
mov.f32 %a5, 0f3E0AAAAB;
mov.f32 %a4, 0fBEB3A903;
mov.f32 %a3, 0f3FB506DD;
mov.f32 %a2, 0fBF03C1E1;
mov.f32 %a1, 0f3EA0D6BB;
mul.f32 %pt, %t, %a5;
add.f32 %pt, %pt, %a4;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a3;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a2;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a1;
mul.f32 %pt, %pt, %t;
// exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
mul.f32 %z2, %ax, %ax;
neg.f32 %neg_z2, %z2;
mov.f32 %log2e, 0f3FB8AA3B;
mul.f32 %neg_z2, %neg_z2, %log2e;
ex2.approx.f32 %exp_neg_z2, %neg_z2;
// erf(|z|) = 1 - poly * exp(-z^2)
mul.f32 %erf_val, %pt, %exp_neg_z2;
sub.f32 %erf_val, %one, %erf_val;
// erf(-z) = -erf(z), so sign-correct
setp.lt.f32 %pred_neg, %z, 0f00000000;
@%pred_neg neg.f32 %erf_val, %erf_val;
// Φ(x) = 0.5 * (1 + erf(x/sqrt(2)))
add.f32 %cdf, %one, %erf_val;
mul.f32 %cdf, %half, %cdf;
// φ(x) = exp(-x²/2) / sqrt(2π)
// exp(-x²/2):
mul.f32 %neg_x2h, %x, %x;
mul.f32 %neg_x2h, %neg_x2h, %half;
neg.f32 %neg_x2h, %neg_x2h;
mul.f32 %neg_x2h, %neg_x2h, %log2e;
ex2.approx.f32 %exp_neg_x2h, %neg_x2h;
// 1/sqrt(2π) = 0.39894228
mov.f32 %inv_sqrt_2pi, 0f3ECC4220;
mul.f32 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;
// d/dx gelu(x) = Φ(x) + x * φ(x)
mul.f32 %x_pdf, %x, %pdf;
add.f32 %d_gelu, %cdf, %x_pdf;
// out = grad * d_gelu
mul.f32 %result, %vg, %d_gelu;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the f64 erf-based GELU backward kernel (entry
/// `gelu_backward_erf_f64_kernel`).
///
/// Computes `out[i] = grad[i] * (Phi(x) + x*phi(x))` where
/// `Phi(x) = 0.5*(1 + erf(x/sqrt(2)))` and `phi(x) = exp(-x^2/2)/sqrt(2*pi)`.
/// erf uses the same rational approximation (and f32-precision constants,
/// widened to f64) as the f32 kernel; exp() is evaluated inline via
/// Cody-Waite argument reduction plus a Horner polynomial
/// (coefficients 1/12! .. 1/3!, then 1/2, 1).
///
/// Fixes applied:
/// - the 1/5! coefficient was written `0d3F811111111111111` (17 hex digits)
///   in both exp expansions; PTX `0d` literals require exactly 16, so ptxas
///   rejects the module. Corrected to `0d3F81111111111111`.
/// - both exp() Horner ladders skipped the 1/3! rung (`0d3FC5555555555555`),
///   jumping from 1/24 to 1/2 and mis-weighting the r^3 term; the rung is
///   restored in both expansions.
pub(crate) const GELU_BACKWARD_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_erf_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
.reg .f64 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
.reg .f64 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
.reg .f64 %d_gelu, %result;
.reg .f64 %p_coef, %a1, %a2, %a3, %a4, %a5;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %pred_ge, %pred_neg;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %pred_ge, %r_tid, %n_reg;
@%pred_ge bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %half, 0d3FE0000000000000;
mov.f64 %z, 0d3FE6A09E60000000;
mul.f64 %z, %x, %z;
abs.f64 %ax, %z;
mov.f64 %p_coef, 0d3FD4F740A0000000;
mul.f64 %t, %p_coef, %ax;
add.f64 %t, %one, %t;
div.rn.f64 %t, %one, %t;
mov.f64 %a5, 0d3FC1555560000000;
mov.f64 %a4, 0dBFD6752060000000;
mov.f64 %a3, 0d3FF6A0DBA0000000;
mov.f64 %a2, 0dBFE0783C20000000;
mov.f64 %a1, 0d3FD41AD760000000;
mul.f64 %pt, %t, %a5;
add.f64 %pt, %pt, %a4;
mul.f64 %pt, %pt, %t;
add.f64 %pt, %pt, %a3;
mul.f64 %pt, %pt, %t;
add.f64 %pt, %pt, %a2;
mul.f64 %pt, %pt, %t;
add.f64 %pt, %pt, %a1;
mul.f64 %pt, %pt, %t;
// exp(-z^2) in full f64
mul.f64 %z2, %ax, %ax;
neg.f64 %neg_z2, %z2;
// --- exp(%neg_z2) ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_z2, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_z2;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp_neg_z2, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp_neg_z2, %exp_neg_z2, %e_nf;
// --- end exp ---
mul.f64 %erf_val, %pt, %exp_neg_z2;
sub.f64 %erf_val, %one, %erf_val;
setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
@%pred_neg neg.f64 %erf_val, %erf_val;
add.f64 %cdf, %one, %erf_val;
mul.f64 %cdf, %half, %cdf;
// phi(x) = exp(-x^2/2) / sqrt(2*pi)
mul.f64 %neg_x2h, %x, %x;
mul.f64 %neg_x2h, %neg_x2h, %half;
neg.f64 %neg_x2h, %neg_x2h;
// --- exp(%neg_x2h) ---
fma.rn.f64 %e_nf, %neg_x2h, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x2h;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp_neg_x2h, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp_neg_x2h, %exp_neg_x2h, %e_nf;
// --- end exp ---
// 1/sqrt(2*pi) = 0.39894228
mov.f64 %inv_sqrt_2pi, 0d3FD9884440000000;
mul.f64 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;
mul.f64 %x_pdf, %x, %pdf;
add.f64 %d_gelu, %cdf, %x_pdf;
mul.f64 %result, %vg, %d_gelu;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the 1-D index-select (gather) kernel (entry
/// `index_select_1d_kernel`).
///
/// Computes `out[i] = input[(u32)indices[i]]` for `i < n_indices`. Indices are
/// stored as f32 and truncated toward zero to u32 (`cvt.rzi.u32.f32`).
/// NOTE(review): the gathered index is not bounds-checked against the input
/// length — callers are presumably expected to validate indices on the host.
pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry index_select_1d_kernel(
.param .u64 input_ptr,
.param .u64 indices_ptr,
.param .u64 out_ptr,
.param .u32 n_indices
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
.reg .u64 %input, %indices, %out, %off, %addr;
.reg .f32 %idx_f, %val;
.reg .pred %p;
ld.param.u64 %input, [input_ptr];
ld.param.u64 %indices, [indices_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n_indices];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// Byte offset for thread
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
// Read indices[tid] (f32 -> u32)
add.u64 %addr, %indices, %off;
ld.global.f32 %idx_f, [%addr];
cvt.rzi.u32.f32 %idx, %idx_f;
// Read input[idx]
cvt.u64.u32 %addr, %idx;
shl.b64 %addr, %addr, 2;
add.u64 %addr, %input, %addr;
ld.global.f32 %val, [%addr];
// Write output[tid]
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the 1-D scatter-add kernel (entry `scatter_add_1d_kernel`),
/// the backward of index-select.
///
/// Computes `grad_input[(u32)indices[i]] += grad_output[i]` for
/// `i < n_indices`, using `atom.global.add.f32` so duplicate indices
/// accumulate correctly. Indices are stored as f32 and truncated toward zero
/// to u32. NOTE(review): the scatter index is not bounds-checked on device.
pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry scatter_add_1d_kernel(
.param .u64 grad_output_ptr,
.param .u64 indices_ptr,
.param .u64 grad_input_ptr,
.param .u32 n_indices
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
.reg .u64 %go, %indices, %gi, %off, %addr;
.reg .f32 %idx_f, %grad_val, %dummy;
.reg .pred %p;
ld.param.u64 %go, [grad_output_ptr];
ld.param.u64 %indices, [indices_ptr];
ld.param.u64 %gi, [grad_input_ptr];
ld.param.u32 %n_reg, [n_indices];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// Byte offset for thread
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
// Read grad_output[tid]
add.u64 %addr, %go, %off;
ld.global.f32 %grad_val, [%addr];
// Read indices[tid] (f32 -> u32)
add.u64 %addr, %indices, %off;
ld.global.f32 %idx_f, [%addr];
cvt.rzi.u32.f32 %idx, %idx_f;
// Atomic add: grad_input[idx] += grad_val
cvt.u64.u32 %addr, %idx;
shl.b64 %addr, %addr, 2;
add.u64 %addr, %gi, %addr;
atom.global.add.f32 %dummy, [%addr], %grad_val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the masked-fill kernel (entry `masked_fill_kernel`).
///
/// Computes `out[i] = fill_value if mask[i] >= 0.5 else input[i]`. The mask is
/// an f32 buffer interpreted as boolean via the `>= 0.5` threshold. One
/// thread per element; threads with `tid >= n` exit early.
pub(crate) const MASKED_FILL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry masked_fill_kernel(
.param .u64 input_ptr,
.param .u64 mask_ptr,
.param .u64 out_ptr,
.param .f32 fill_value,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %input, %mask, %out, %off;
.reg .f32 %in_val, %mask_val, %fill, %result, %half;
.reg .pred %p, %pmask;
ld.param.u64 %input, [input_ptr];
ld.param.u64 %mask, [mask_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.f32 %fill, [fill_value];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %input, %input, %off;
add.u64 %mask, %mask, %off;
add.u64 %out, %out, %off;
ld.global.f32 %in_val, [%input];
ld.global.f32 %mask_val, [%mask];
mov.f32 %half, 0f3F000000;
setp.ge.f32 %pmask, %mask_val, %half;
selp.f32 %result, %fill, %in_val, %pmask;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the masked-zero kernel (entry `masked_zero_kernel`), the
/// backward of masked-fill.
///
/// Computes `out[i] = 0 if mask[i] >= 0.5 else grad[i]` — gradient is zeroed
/// at the positions that were filled in the forward pass. Uses the same
/// f32-mask `>= 0.5` convention as `MASKED_FILL_PTX`.
pub(crate) const MASKED_ZERO_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry masked_zero_kernel(
.param .u64 grad_ptr,
.param .u64 mask_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %mask, %out, %off;
.reg .f32 %vg, %mask_val, %zero, %result, %half;
.reg .pred %p, %pmask;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %mask, [mask_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %mask, %mask, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %mask_val, [%mask];
mov.f32 %zero, 0f00000000;
mov.f32 %half, 0f3F000000;
setp.ge.f32 %pmask, %mask_val, %half;
selp.f32 %result, %zero, %vg, %pmask;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the f32 sigmoid backward kernel (entry
/// `sigmoid_backward_kernel`).
///
/// Computes `out[i] = grad[i] * o * (1 - o)` where `o` is the saved FORWARD
/// OUTPUT (`output_ptr`), not the input — the usual trick of reusing
/// `sigmoid(x)` to avoid recomputing it. One thread per element.
pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sigmoid_backward_kernel(
.param .u64 grad_ptr,
.param .u64 output_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %output, %out, %off;
.reg .f32 %vg, %vo, %one, %one_minus_o, %result;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %output, [output_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %output, %output, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %vo, [%output];
mov.f32 %one, 0f3F800000;
sub.f32 %one_minus_o, %one, %vo;
mul.f32 %result, %vo, %one_minus_o;
mul.f32 %result, %vg, %result;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the f32 tanh backward kernel (entry `tanh_backward_kernel`).
///
/// Computes `out[i] = grad[i] * (1 - o^2)` where `o` is the saved FORWARD
/// OUTPUT (`output_ptr`) — i.e. `tanh'(x) = 1 - tanh(x)^2` evaluated from the
/// cached forward result. One thread per element.
pub(crate) const TANH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry tanh_backward_kernel(
.param .u64 grad_ptr,
.param .u64 output_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %output, %out, %off;
.reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %output, [output_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %output, %output, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %vo, [%output];
mov.f32 %one, 0f3F800000;
mul.f32 %o_sq, %vo, %vo;
sub.f32 %one_minus_sq, %one, %o_sq;
mul.f32 %result, %vg, %one_minus_sq;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX source for the row-wise softmax backward kernel (entry
/// `softmax_backward_kernel`).
///
/// One CTA per row. Phase 1: each thread accumulates a block-strided partial
/// of `dot = sum_j grad[j] * output[j]`, which is then combined with a
/// shared-memory tree reduction. Phase 2: `out[j] = output[j] * (grad[j] - dot)`.
///
/// The reduction repeatedly halves `%bdim`, so it assumes the launch uses a
/// power-of-two block size of at most 256 (the `sdata[256]` shared buffer) —
/// confirm at the launch site. All threads of the CTA reach both `bar.sync`
/// barriers (only whole-row early exit occurs before them).
pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_backward_kernel(\n\
.param .u64 grad_ptr,\n\
.param .u64 output_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
.reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
.reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;\n\
.reg .pred %p, %loop_p, %reduce_p;\n\
\n\
ld.param.u64 %grad, [grad_ptr];\n\
ld.param.u64 %output, [output_ptr];\n\
ld.param.u64 %out, [out_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
// row_off = bid * cols * 4 (byte offset)\n\
cvt.u64.u32 %row_off, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off, %row_off, %off;\n\
shl.b64 %row_off, %row_off, 2;\n\
\n\
// Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements\n\
mov.f32 %dot, 0f00000000;\n\
mov.u32 %j, %r_tid;\n\
DOT_LOOP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra DOT_LOOP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %grad, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %vg, [%saddr];\n\
add.u64 %saddr, %output, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %vo, [%saddr];\n\
fma.rn.f32 %dot, %vg, %vo, %dot;\n\
add.u32 %j, %j, %bdim;\n\
bra DOT_LOOP;\n\
DOT_LOOP_DONE:\n\
\n\
// Store partial dot into shared memory and reduce\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %dot;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
DOT_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra DOT_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra DOT_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %dot, [%saddr];\n\
add.f32 %dot, %dot, %other_val;\n\
st.shared.f32 [%saddr], %dot;\n\
DOT_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra DOT_REDUCE;\n\
DOT_REDUCE_DONE:\n\
\n\
// Broadcast dot to all threads\n\
ld.shared.f32 %dot, [sdata];\n\
bar.sync 0;\n\
\n\
// Phase 2: out[j] = output[j] * (grad[j] - dot)\n\
mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra WRITE_LOOP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %grad, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %vg, [%saddr];\n\
add.u64 %saddr, %output, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %vo, [%saddr];\n\
sub.f32 %diff, %vg, %dot;\n\
mul.f32 %result, %vo, %diff;\n\
add.u64 %saddr, %out, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
st.global.f32 [%saddr], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
/// PTX for row-wise f32 log-softmax: `out[i][j] = x[i][j] - (max_i + ln(sum_j exp(x[i][j] - max_i)))`.
///
/// One thread block handles one row (`%ctaid.x` is the row index; blocks past
/// `rows` exit early) and threads stride over columns in steps of `%ntid.x`.
/// Three phases: (1) block-wide max reduction, (2) block-wide sum of
/// `exp(x - max)`, (3) write `x - log_sum_exp`. Both reductions use the
/// 256-slot `sdata` shared array with a halving tree seeded from `%ntid.x`,
/// so the launch presumably needs a power-of-two block size of at most 256
/// threads — TODO confirm against the launch site.
/// `exp`/`log` are built from `ex2.approx`/`lg2.approx` scaled by log2(e)
/// (0x3FB8AA3B) and ln(2) (0x3F317218), so results carry the approx-
/// instruction error.
pub(crate) const LOG_SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: find max across row (grid-stride over columns)\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    // Shared-memory tree reduction for max\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    // Broadcast max to all threads\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: compute partial sum of exp(x[j] - max)\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    // exp(x) = exp2(x * log2(e)), log2(e) = 0x3FB8AA3B\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    // Shared-memory tree reduction for sum\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum to all threads, compute log_sum_exp = max + log(sum)\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
    // log(x) = log2(x) / log2(e) = log2(x) * ln(2)\n\
    // ln(2) = 0x3F317218\n\
    lg2.approx.f32 %log_sum_exp, %sum_val;\n\
    mul.f32 %log_sum_exp, %log_sum_exp, 0f3F317218;\n\
    add.f32 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
    // Phase 3: out[j] = x[j] - log_sum_exp\n\
    mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_OUTPUT_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %in, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %val, [%saddr];\n\
    sub.f32 %result, %val, %log_sum_exp;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
/// PTX for row-wise f64 log-softmax (same three-phase scheme as the f32
/// kernel: block-wide max reduction, block-wide sum of `exp(x - max)`, then
/// write `x - log_sum_exp`; one block per row, threads striding over columns,
/// 256-slot shared tree reduction).
///
/// f64 has no `ex2`/`lg2` instruction, so `exp` is inlined: argument
/// reduction r = x - n*ln2 (split-constant ln2 for accuracy), a Horner
/// polynomial over the 1/12!..1/2 Taylor coefficients, then scaling by a
/// bit-constructed 2^n (`(n + 1023) << 52`). `log(sum)` is inlined via the
/// atanh series: decompose the mantissa m in [1,2), f = (m-1)/(m+1),
/// log(m) = 2f*(1 + f^2/3 + ... + f^10/11), plus exponent*ln2.
/// NOTE(review): the 2^n scaling does not clamp n, so inputs where
/// x - max < -1022*ln2 produce garbage exponent bits rather than a clean
/// underflow to 0 — matches the f32 kernel's lack of clamping; confirm the
/// expected input range at the call site.
///
/// Fix applied: the exp polynomial previously contained the malformed
/// 17-hex-digit literal `0d3F811111111111111` (PTX `0d` immediates take
/// exactly 16 hex digits) and was missing the 1/4! term, which shifted every
/// higher-order coefficient one degree off. The chain now runs
/// 1/12! .. 1/5! (0d3F81111111111111 = 1/120), 1/4! (0d3FA5555555555555),
/// 1/3!, 1/2, 1.
pub(crate) const LOG_SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_f64_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f64 %other_val;\n\
    .reg .pred %reduce_p;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
    .reg .u64 %l_xbits, %l_mbits, %l_bias;\n\
    .reg .s64 %l_exp64;\n\
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    mov.f64 %max_val, 0dFFF0000000000000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    max.f64 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %max_val, [%saddr];\n\
    max.f64 %max_val, %max_val, %other_val;\n\
    st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f64 %sum_val, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    sub.f64 %val, %val, %max_val;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %exp_val, %exp_val, %e_nf;\n\
    add.f64 %sum_val, %sum_val, %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_val, [%saddr];\n\
    add.f64 %sum_val, %sum_val, %other_val;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_val, [sdata];\n\
    bar.sync 0;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.b64 %l_xbits, %sum_val;\n\
    shr.u64 %l_exp64, %l_xbits, 52;\n\
    and.b64 %l_exp64, %l_exp64, 2047;\n\
    sub.s64 %l_exp64, %l_exp64, 1023;\n\
    cvt.rn.f64.s64 %l_nf, %l_exp64;\n\
    mov.u64 %l_bias, 0x3FF0000000000000;\n\
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;\n\
    or.b64 %l_mbits, %l_mbits, %l_bias;\n\
    mov.b64 %l_m, %l_mbits;\n\
    sub.f64 %l_f, %l_m, %e_one;\n\
    add.f64 %l_s, %l_m, %e_one;\n\
    div.rn.f64 %l_f, %l_f, %l_s;\n\
    mul.f64 %l_f2, %l_f, %l_f;\n\
    mov.f64 %l_p, 0d3FB745D1745D1746;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;\n\
    mul.f64 %l_p, %l_p, %l_f;\n\
    add.f64 %l_p, %l_p, %l_p;\n\
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;\n\
    fma.rn.f64 %log_sum_exp, %l_nf, %l_ln2, %l_p;\n\
    add.f64 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
    mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_OUTPUT_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %in, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %val, [%saddr];\n\
    sub.f64 %result, %val, %log_sum_exp;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f64 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
/// PTX backward pass of f32 log-softmax:
/// `out[j] = grad[j] - exp(output[j]) * sum(grad)`, where `exp(output[j])`
/// recovers the softmax probability from the forward log-softmax output via
/// `ex2.approx` scaled by log2(e).
///
/// One block per row (`%ctaid.x`); threads stride over columns. The row-wide
/// `sum(grad)` uses the 256-slot shared-memory halving-tree reduction seeded
/// from `%ntid.x`, so the block size presumably must be a power of two of at
/// most 256 — TODO confirm against the launch site.
pub(crate) const LOG_SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: compute partial sum_grad = sum(grad[j]) for this thread's elements\n\
    mov.f32 %sum_grad, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.f32 %sum_grad, %sum_grad, %vg;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
    // Store partial sum into shared memory and reduce\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_grad;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_grad, [%saddr];\n\
    add.f32 %sum_grad, %sum_grad, %other_val;\n\
    st.shared.f32 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum_grad to all threads\n\
    ld.shared.f32 %sum_grad, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = grad[j] - exp(output[j]) * sum_grad\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    // exp(log_softmax_output) = softmax probability\n\
    mul.f32 %vo, %vo, 0f3FB8AA3B;\n\
    ex2.approx.f32 %softmax_j, %vo;\n\
    // out[j] = grad[j] - softmax[j] * sum_grad\n\
    mul.f32 %result, %softmax_j, %sum_grad;\n\
    sub.f32 %result, %vg, %result;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
/// PTX backward pass of f64 log-softmax:
/// `out[j] = grad[j] - exp(output[j]) * sum(grad)`. One block per row;
/// threads stride over columns; the row-wide `sum(grad)` uses the 256-slot
/// shared halving-tree reduction. f64 has no `ex2` instruction, so
/// `exp(output[j])` is inlined: reduce r = x - n*ln2 (split-constant ln2),
/// evaluate the 1/12!..1/2 Taylor polynomial with Horner/fma, then scale by
/// a bit-constructed 2^n (`(n + 1023) << 52`).
/// NOTE(review): the 2^n construction does not clamp n; log-softmax outputs
/// are <= 0, and values below about -708 would produce garbage exponent bits
/// instead of underflowing to 0 — confirm input range at the call site.
///
/// Fix applied: the exp polynomial previously contained the malformed
/// 17-hex-digit literal `0d3F811111111111111` (PTX `0d` immediates take
/// exactly 16 hex digits) and was missing the 1/4! term, shifting every
/// higher-order coefficient one degree off. The chain now runs
/// 1/12! .. 1/5! (0d3F81111111111111 = 1/120), 1/4! (0d3FA5555555555555),
/// 1/3!, 1/2, 1.
pub(crate) const LOG_SOFTMAX_BACKWARD_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_f64_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    mov.f64 %sum_grad, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vg, [%saddr];\n\
    add.f64 %sum_grad, %sum_grad, %vg;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_grad;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_grad, [%saddr];\n\
    add.f64 %sum_grad, %sum_grad, %other_val;\n\
    st.shared.f64 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_grad, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vo, [%saddr];\n\
    // exp(log_softmax_output) — inline f64 exp\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    mul.f64 %e_nf, %vo, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %vo;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %softmax_j, %e_p, %e_r, %e_one;\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %softmax_j, %softmax_j, %e_nf;\n\
    mul.f64 %result, %softmax_j, %sum_grad;\n\
    sub.f64 %result, %vg, %result;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f64 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
/// PTX that reduces an f32 array to per-block partial sums: each thread
/// accumulates a grid-stride slice, the block tree-reduces in shared memory,
/// and thread 0 writes `out[blockIdx] = block sum`. A further pass over `out`
/// is still needed for the grand total.
/// NOTE(review): the tree reduction starts from a hard-coded `%half = 128`
/// and reads `sdata[tid + half]`, so this kernel is only correct when
/// launched with exactly 256 threads per block — confirm at the launch site.
pub(crate) const REDUCE_SUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
// Shared memory for intra-block reduction (256 floats = 1024 bytes).
.shared .align 4 .f32 sdata[256];
.visible .entry reduce_sum_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
    .reg .u64 %in, %out, %off;
    .reg .f32 %sum, %other;
    .reg .pred %p, %ptid;
    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    mov.u32 %tid, %tid.x;
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %gdim, %nctaid.x;
    // Grid-stride accumulation: each thread sums multiple elements.
    // idx = bid * bdim + tid; stride = bdim * gdim
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %sum, 0f00000000;
GRID_LOOP:
    setp.ge.u32 %p, %idx, %n_reg;
    @%p bra GRID_DONE;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %other, [%off];
    add.f32 %sum, %sum, %other;
    add.u32 %idx, %idx, %stride;
    bra GRID_LOOP;
GRID_DONE:
    // Write thread's partial sum to shared memory.
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    st.shared.f32 [sdata + %off], %sum;
    bar.sync 0;
    // Tree reduction in shared memory.
    mov.u32 %half, 128;
TREE_LOOP:
    setp.lt.u32 %p, %half, 1;
    @%p bra TREE_DONE;
    setp.ge.u32 %ptid, %tid, %half;
    @%ptid bra TREE_SKIP;
    // Load partner's value from sdata[tid + half].
    add.u32 %idx, %tid, %half;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    ld.shared.f32 %other, [sdata + %off];
    // Load own value.
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    ld.shared.f32 %sum, [sdata + %off];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [sdata + %off], %sum;
TREE_SKIP:
    bar.sync 0;
    shr.u32 %half, %half, 1;
    bra TREE_LOOP;
TREE_DONE:
    // Thread 0 writes block result.
    setp.ne.u32 %ptid, %tid, 0;
    @%ptid bra END;
    ld.shared.f32 %sum, [sdata];
    cvt.u64.u32 %off, %bid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %sum;
END:
    ret;
}
";
#[cfg(feature = "cuda")]
/// PTX summing an f32 tensor along one axis. The tensor is viewed as
/// `(outer_size, axis_size, inner_size)`; one thread per output element
/// (guarded by `total_output`), each walking the `axis_size` elements at
/// stride `inner_size` sequentially and writing the sum to
/// `output[outer_idx * inner_size + inner_idx]`.
pub(crate) const SUM_AXIS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sum_axis_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 axis_size,
    .param .u32 inner_size,
    .param .u32 total_output
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %sum;
    .reg .pred %p, %lp;
    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %axis_sz, [axis_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total_output];
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;
    // outer_idx = r_tid / inner_size
    div.u32 %outer_idx, %r_tid, %inner_sz;
    // inner_idx = r_tid % inner_size
    rem.u32 %inner_idx, %r_tid, %inner_sz;
    // base = outer_idx * axis_size * inner_size + inner_idx
    mul.lo.u32 %tmp, %outer_idx, %axis_sz;
    mul.lo.u32 %tmp, %tmp, %inner_sz;
    add.u32 %tmp, %tmp, %inner_idx;
    mov.f32 %sum, 0f00000000;
    mov.u32 %k, 0;
SUM_LOOP:
    setp.ge.u32 %lp, %k, %axis_sz;
    @%lp bra SUM_LOOP_DONE;
    // addr = in + (tmp + k * inner_size) * 4
    mul.lo.u32 %inner_idx, %k, %inner_sz;
    add.u32 %inner_idx, %tmp, %inner_idx;
    cvt.u64.u32 %off, %inner_idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    add.f32 %sum, %sum, %val;
    add.u32 %k, %k, 1;
    bra SUM_LOOP;
SUM_LOOP_DONE:
    // output[r_tid] = sum
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %sum;
DONE:
    ret;
}
";
#[cfg(feature = "cuda")]
/// PTX inclusive prefix sum (cumsum) along one dimension of an f32 tensor
/// viewed as `(outer_size, dim_size, inner_size)`. One thread owns one
/// (outer, inner) lane and scans its `dim_size` elements sequentially,
/// storing the running sum at every step.
/// NOTE(review): the `total` parameter is loaded into `%n_reg` but the
/// bounds guard uses a locally recomputed `outer_size * inner_size`, so
/// `total` is effectively unused.
pub(crate) const CUMSUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry cumsum_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;
    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    // total threads = outer * inner
    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;
    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;
    // base = outer_idx * dim_size * inner_size + inner_idx
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;
    mov.f32 %acc, 0f00000000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;
    // idx = base + k * inner_size
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    add.f32 %acc, %acc, %val;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;
    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:
DONE:
    ret;
}
";
#[cfg(feature = "cuda")]
/// PTX inclusive running product (cumprod) along one dimension of an f32
/// tensor viewed as `(outer_size, dim_size, inner_size)`. One thread owns one
/// (outer, inner) lane; the accumulator starts at 1.0 (0f3F800000) and each
/// step multiplies in the next element and stores the running product.
/// NOTE(review): as in the cumsum kernel, `total` is loaded but the guard
/// uses a recomputed `outer_size * inner_size`.
pub(crate) const CUMPROD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry cumprod_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;
    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;
    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;
    // acc = 1.0
    mov.f32 %acc, 0f3F800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    mul.f32 %acc, %acc, %val;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;
    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:
DONE:
    ret;
}
";
#[cfg(feature = "cuda")]
/// PTX running maximum (cummax) along one dimension of an f32 tensor viewed
/// as `(outer_size, dim_size, inner_size)`, also recording the argmax of each
/// prefix. One thread owns one (outer, inner) lane; the accumulator starts at
/// -inf (bit pattern 0xFF800000) and `%best_k` tracks the position of the
/// latest strict improvement (`setp.gt.f32`).
/// NOTE(review): the prefix-argmax index is stored as an f32 via
/// `cvt.rn.f32.u32`, so the `indices_ptr` buffer holds floats, not integers —
/// indices above 2^24 would lose precision; confirm callers expect this.
pub(crate) const CUMMAX_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry cummax_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_max;
    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;
    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;
    mov.b32 %acc, 0xFF800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    setp.gt.f32 %is_new_max, %val, %acc;
    @%is_new_max mov.u32 %best_k, %k;
    max.f32 %acc, %acc, %val;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;
    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;
    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:
DONE:
    ret;
}
";
#[cfg(feature = "cuda")]
/// PTX running minimum (cummin) along one dimension of an f32 tensor viewed
/// as `(outer_size, dim_size, inner_size)`, mirroring the cummax kernel: the
/// accumulator starts at +inf (bit pattern 0x7F800000), `setp.lt.f32` tracks
/// the position of the latest strict improvement, and the prefix-argmin index
/// is stored as an f32 into `indices_ptr` (see the cvt.rn.f32.u32 below —
/// the indices buffer holds floats, not integers).
pub(crate) const CUMMIN_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry cummin_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_min;
    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;
    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;
    mov.b32 %acc, 0x7F800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    setp.lt.f32 %is_new_min, %val, %acc;
    @%is_new_min mov.u32 %best_k, %k;
    min.f32 %acc, %acc, %val;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;
    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;
    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:
DONE:
    ret;
}
";
#[cfg(feature = "cuda")]
/// PTX numerically-stable running log-sum-exp (logcumsumexp) along one
/// dimension of an f32 tensor viewed as `(outer_size, dim_size, inner_size)`.
/// One thread owns one (outer, inner) lane; each step applies
/// `acc = m + log(exp(acc - m) + exp(x - m))` with `m = max(acc, x)` and the
/// accumulator seeded at -inf (bit pattern 0xFF800000), so the first element
/// yields `acc = x` exactly. `exp`/`log` use `ex2.approx`/`lg2.approx`
/// scaled by log2(e)/ln(2), so results carry approx-instruction error.
pub(crate) const LOGCUMSUMEXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry logcumsumexp_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc, %m, %ea, %ev, %s, %ls, %log2e, %ln2;
    .reg .pred %p, %lp;
    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];
    // log2(e) = 1.4426950408... -> 0x3FB8AA3B
    mov.b32 %log2e, 0x3FB8AA3B;
    // ln(2) = 0.6931471805... -> 0x3F317218
    mov.b32 %ln2, 0x3F317218;
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;
    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;
    // acc = -inf
    mov.b32 %acc, 0xFF800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    // Numerically stable: m = max(acc, x)
    max.f32 %m, %acc, %val;
    // exp(acc - m): (acc - m) * log2(e) -> ex2
    sub.f32 %ea, %acc, %m;
    mul.f32 %ea, %ea, %log2e;
    ex2.approx.f32 %ea, %ea;
    // exp(x - m): (x - m) * log2(e) -> ex2
    sub.f32 %ev, %val, %m;
    mul.f32 %ev, %ev, %log2e;
    ex2.approx.f32 %ev, %ev;
    // sum
    add.f32 %s, %ea, %ev;
    // log(sum) = lg2(sum) * ln(2)
    lg2.approx.f32 %ls, %s;
    mul.f32 %ls, %ls, %ln2;
    // acc = m + log(sum)
    add.f32 %acc, %m, %ls;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;
    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:
DONE:
    ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for a numerically stable f64 `logcumsumexp` scan along one dimension.
///
/// One thread handles one `(outer, inner)` pair and scans sequentially over
/// `dim_size`, maintaining the running log-sum-exp recurrence
/// `acc = m + ln(exp(acc - m) + exp(x - m))` with `m = max(acc, x)`.
///
/// Kernel parameters: `input_ptr`/`output_ptr` (f64 buffers), `outer_size`,
/// `dim_size`, `inner_size`, `total` (unused beyond the bounds check inputs).
///
/// PTX has no f64 `ex2.approx`/`lg2.approx`, so `exp` and `ln` are inlined:
/// `exp` uses a two-part Cody-Waite ln(2) reduction plus a degree-11
/// Taylor-style polynomial and rebuilds 2^n via exponent-field bit assembly;
/// `ln` extracts exponent/mantissa bits and evaluates an atanh-series
/// polynomial in `(m-1)/(m+1)`.
///
/// Fixes relative to the earlier revision:
/// - the 1/5040 and (here) 1/120-family coefficient was written with 17 hex
///   digits (`0d3F811111111111111`), which is not a valid PTX f64 literal and
///   made the module fail to assemble; it is now the 16-digit
///   `0d3F81111111111111`;
/// - `acc` starts at -inf, so the first iteration computed
///   `exp(-inf - m)` through the polynomial path and produced NaN
///   (inf - inf in the reduction step); arguments below -700 are now
///   flushed to 0, and an all-(-inf) prefix short-circuits to -inf.
pub(crate) const LOGCUMSUMEXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry logcumsumexp_f64_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 outer_size,
.param .u32 dim_size,
.param .u32 inner_size,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
.reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
.reg .u64 %in, %out, %off, %addr;
.reg .f64 %val, %acc, %m, %ea, %ev, %s, %ls;
.reg .pred %p, %lp;
.reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .u64 %l_xbits, %l_mbits, %l_bias;
.reg .s64 %l_exp64;
.reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %outer_sz, [outer_size];
ld.param.u32 %dim_sz, [dim_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
mul.lo.u32 %tmp, %outer_sz, %inner_sz;
setp.ge.u32 %p, %r_tid, %tmp;
@%p bra DONE;
div.u32 %outer_idx, %r_tid, %inner_sz;
rem.u32 %inner_idx, %r_tid, %inner_sz;
mul.lo.u32 %base, %outer_idx, %dim_sz;
mul.lo.u32 %base, %base, %inner_sz;
add.u32 %base, %base, %inner_idx;
// acc = -inf
mov.b64 %acc, 0xFFF0000000000000;
mov.u32 %k, 0;
SCAN_LOOP:
setp.ge.u32 %lp, %k, %dim_sz;
@%lp bra SCAN_DONE;
mul.lo.u32 %idx, %k, %inner_sz;
add.u32 %idx, %base, %idx;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %addr, %in, %off;
ld.global.f64 %val, [%addr];
max.f64 %m, %acc, %val;
// All-(-inf) prefix: the result stays -inf. Short-circuiting also avoids
// the (-inf) - (-inf) = NaN the subtractions below would produce.
setp.gt.f64 %lp, %m, 0dFFF0000000000000;
@%lp bra LSE_COMPUTE;
mov.f64 %acc, 0dFFF0000000000000;
bra LSE_STORE;
LSE_COMPUTE:
mov.f64 %e_one, 0d3FF0000000000000;
mov.f64 %e_half, 0d3FE0000000000000;
// --- inline exp(acc - m) -> %ea ---
sub.f64 %ea, %acc, %m;
mov.f64 %l_f, %ea;
mul.f64 %e_nf, %ea, 0d3FF71547652B82FE;
cvt.rni.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %ea;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %e_one;
fma.rn.f64 %ea, %e_p, %e_r, %e_one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ea, %ea, %e_nf;
// Underflow guard: below -700 the exponent-bit 2^n trick is out of range
// (and the argument may be -inf); exp is effectively 0 there.
setp.lt.f64 %lp, %l_f, 0dC085E00000000000;
@%lp mov.f64 %ea, 0d0000000000000000;
// --- inline exp(val - m) -> %ev ---
sub.f64 %ev, %val, %m;
mov.f64 %l_f, %ev;
mul.f64 %e_nf, %ev, 0d3FF71547652B82FE;
cvt.rni.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %ev;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %e_one;
fma.rn.f64 %ev, %e_p, %e_r, %e_one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ev, %ev, %e_nf;
setp.lt.f64 %lp, %l_f, 0dC085E00000000000;
@%lp mov.f64 %ev, 0d0000000000000000;
add.f64 %s, %ea, %ev;
// --- inline ln(%s) -> %ls ---
mov.b64 %l_xbits, %s;
shr.u64 %l_exp64, %l_xbits, 52;
and.b64 %l_exp64, %l_exp64, 2047;
sub.s64 %l_exp64, %l_exp64, 1023;
cvt.rn.f64.s64 %l_nf, %l_exp64;
mov.u64 %l_bias, 0x3FF0000000000000;
and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
or.b64 %l_mbits, %l_mbits, %l_bias;
mov.b64 %l_m, %l_mbits;
sub.f64 %l_f, %l_m, %e_one;
add.f64 %l_s, %l_m, %e_one;
div.rn.f64 %l_f, %l_f, %l_s;
mul.f64 %l_f2, %l_f, %l_f;
mov.f64 %l_p, 0d3FB745D1745D1746;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;
mul.f64 %l_p, %l_p, %l_f;
add.f64 %l_p, %l_p, %l_p;
mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
fma.rn.f64 %ls, %l_nf, %l_ln2, %l_p;
add.f64 %acc, %m, %ls;
LSE_STORE:
add.u64 %addr, %out, %off;
st.global.f64 [%addr], %acc;
add.u32 %k, %k, 1;
bra SCAN_LOOP;
SCAN_DONE:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for the f32 LayerNorm forward kernel over a `rows x cols` matrix.
///
/// `%ctaid.x` is used as the row index (one CTA per row); threads stride
/// over the columns of that row by `%ntid.x`. Three phases, each separated
/// by `bar.sync`:
/// 1. mean: per-thread partial sums, tree-reduced in the shared `sdata` array;
/// 2. variance: partial sums of `(x - mean)^2`, same reduction;
/// 3. normalize: `out = w[j] * (x - mean) * inv_std + b[j]`.
///
/// Kernel parameters: `in_ptr`/`out_ptr` (rows*cols f32), `w_ptr`/`b_ptr`
/// (cols f32 weight and bias), `rows`, `cols`, `eps` (added to the variance
/// before the reciprocal square root).
///
/// Uses approximate math (`div.approx`, `sqrt.approx`, `rcp.approx`), so
/// results differ from IEEE-exact normalization in the low bits.
///
/// NOTE(review): the tree reduction repeatedly halves `%ntid.x` and `sdata`
/// holds 256 floats, so this assumes a power-of-two blockDim.x <= 256 —
/// confirm at the launch site.
pub(crate) const LAYERNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry layernorm_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u64 w_ptr,
.param .u64 b_ptr,
.param .u32 rows,
.param .u32 cols,
.param .f32 eps
) {
.reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
.reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
.reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
.reg .pred %p, %lp, %rp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %w, [w_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
ld.param.f32 %eps_r, [eps];
mov.u64 %sbase, sdata;
mov.u32 %r_bid, %ctaid.x;
mov.u32 %r_bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %r_bid, %rows_reg;
@%p bra DONE;
cvt.u64.u32 %row_off, %r_bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 2;
cvt.rn.f32.u32 %n_f, %cols_reg;
mov.f32 %mean, 0f00000000;
mov.u32 %j, %r_tid;
SM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra SMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.f32 %val, [%off];
add.f32 %mean, %mean, %val;
add.u32 %j, %j, %r_bdim;
bra SM;
SMD:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %mean;
bar.sync 0;
mov.u32 %half, %r_bdim;
MR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra MRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra MRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %mean, [%saddr];
add.f32 %mean, %mean, %other_val;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %mean;
MRS:
bar.sync 0;
bra MR;
MRD:
ld.shared.f32 %mean, [%sbase];
div.approx.f32 %mean, %mean, %n_f;
bar.sync 0;
mov.f32 %var, 0f00000000;
mov.u32 %j, %r_tid;
SV:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra SVD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.f32 %val, [%off];
sub.f32 %diff, %val, %mean;
fma.rn.f32 %var, %diff, %diff, %var;
add.u32 %j, %j, %r_bdim;
bra SV;
SVD:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %var;
bar.sync 0;
mov.u32 %half, %r_bdim;
VR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra VRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra VRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %var, [%saddr];
add.f32 %var, %var, %other_val;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %var;
VRS:
bar.sync 0;
bra VR;
VRD:
ld.shared.f32 %var, [%sbase];
div.approx.f32 %var, %var, %n_f;
add.f32 %var, %var, %eps_r;
sqrt.approx.f32 %inv_std, %var;
rcp.approx.f32 %inv_std, %inv_std;
bar.sync 0;
mov.u32 %j, %r_tid;
NM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra NMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.f32 %val, [%off];
sub.f32 %normed, %val, %mean;
mul.f32 %normed, %normed, %inv_std;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %w, %off;
ld.global.f32 %wv, [%off];
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %b, %off;
ld.global.f32 %bv, [%off];
fma.rn.f32 %result, %wv, %normed, %bv;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
add.u64 %off, %off, %row_off;
st.global.f32 [%off], %result;
add.u32 %j, %j, %r_bdim;
bra NM;
NMD:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for the f32 LayerNorm backward kernel.
///
/// One CTA per row (`%ctaid.x`); threads stride over columns by `%ntid.x`.
/// Phases (each ending in a shared-memory tree reduction where needed):
/// 1. recompute `mean` of the row;
/// 2. recompute `var` and `inv_std = 1/sqrt(var + eps)`;
/// 3. per-column: `dl_dx_hat = go * w`, reduce `sum1 = sum(dl_dx_hat)` and
///    `sum2 = sum(dl_dx_hat * x_hat)`; simultaneously accumulates
///    `grad_weight[j] += go * x_hat` and `grad_bias[j] += go` with
///    `atom.global.add.f32` (rows race on the same `j`, hence the atomics);
/// 4. `grad_input[j] = inv_std * (dl_dx_hat - sum1/n - x_hat * sum2/n)`.
///
/// Kernel parameters: `in_ptr`, `grad_out_ptr`, `w_ptr` (inputs),
/// `grad_in_ptr`, `grad_w_ptr`, `grad_b_ptr` (outputs), `rows`, `cols`, `eps`.
///
/// NOTE(review): the atomic accumulation assumes `grad_w_ptr`/`grad_b_ptr`
/// are zero-initialized by the caller — confirm at the call site. The tree
/// reduction assumes a power-of-two blockDim.x <= 256 (`sdata[256]`).
/// Statistics use approximate div/sqrt/rcp, matching the forward kernel.
pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry layernorm_backward_kernel(
.param .u64 in_ptr,
.param .u64 grad_out_ptr,
.param .u64 w_ptr,
.param .u64 grad_in_ptr,
.param .u64 grad_w_ptr,
.param .u64 grad_b_ptr,
.param .u32 rows,
.param .u32 cols,
.param .f32 eps
) {
.reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
.reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
.reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
.reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
.reg .pred %p, %lp, %rp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %go, [grad_out_ptr];
ld.param.u64 %w, [w_ptr];
ld.param.u64 %gi, [grad_in_ptr];
ld.param.u64 %gw, [grad_w_ptr];
ld.param.u64 %gb, [grad_b_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
ld.param.f32 %eps_r, [eps];
mov.u64 %sbase, sdata;
mov.u32 %r_bid, %ctaid.x;
mov.u32 %r_bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %r_bid, %rows_reg;
@%p bra LNB_DONE;
// row_off = bid * cols * 4 (byte offset for this row)
cvt.u64.u32 %row_off, %r_bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 2;
cvt.rn.f32.u32 %n_f, %cols_reg;
// ===== Phase 1: Compute mean =====
mov.f32 %mean, 0f00000000;
mov.u32 %j, %r_tid;
LNB_SM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra LNB_SMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
add.f32 %mean, %mean, %val;
add.u32 %j, %j, %r_bdim;
bra LNB_SM;
LNB_SMD:
// Shared memory reduce for mean
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %mean;
bar.sync 0;
mov.u32 %half, %r_bdim;
LNB_MR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra LNB_MRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra LNB_MRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %mean, [%saddr];
add.f32 %mean, %mean, %other_val;
st.shared.f32 [%saddr], %mean;
LNB_MRS:
bar.sync 0;
bra LNB_MR;
LNB_MRD:
ld.shared.f32 %mean, [%sbase];
div.approx.f32 %mean, %mean, %n_f;
bar.sync 0;
// ===== Phase 2: Compute variance =====
mov.f32 %var, 0f00000000;
mov.u32 %j, %r_tid;
LNB_SV:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra LNB_SVD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
sub.f32 %diff, %val, %mean;
fma.rn.f32 %var, %diff, %diff, %var;
add.u32 %j, %j, %r_bdim;
bra LNB_SV;
LNB_SVD:
// Shared memory reduce for variance
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %var;
bar.sync 0;
mov.u32 %half, %r_bdim;
LNB_VR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra LNB_VRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra LNB_VRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %var, [%saddr];
add.f32 %var, %var, %other_val;
st.shared.f32 [%saddr], %var;
LNB_VRS:
bar.sync 0;
bra LNB_VR;
LNB_VRD:
ld.shared.f32 %var, [%sbase];
div.approx.f32 %var, %var, %n_f;
add.f32 %var, %var, %eps_r;
sqrt.approx.f32 %inv_std, %var;
rcp.approx.f32 %inv_std, %inv_std;
bar.sync 0;
// ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
// Also accumulate grad_weight and grad_bias via atomicAdd
mov.f32 %sum1, 0f00000000;
mov.f32 %sum2, 0f00000000;
mov.u32 %j, %r_tid;
LNB_S12:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra LNB_S12D;
// Load input[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
// x_hat = (val - mean) * inv_std
sub.f32 %x_hat, %val, %mean;
mul.f32 %x_hat, %x_hat, %inv_std;
// Load grad_output[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %go, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %gov, [%addr];
// Load weight[j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %w, %off;
ld.global.f32 %wv, [%addr];
// dl_dx_hat = grad_output * weight
mul.f32 %dl_dx_hat, %gov, %wv;
// Accumulate sums
add.f32 %sum1, %sum1, %dl_dx_hat;
fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
// atomicAdd grad_weight[j] += grad_output * x_hat
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %gw, %off;
mul.f32 %result, %gov, %x_hat;
atom.global.add.f32 %result, [%addr], %result;
// atomicAdd grad_bias[j] += grad_output
add.u64 %addr, %gb, %off;
atom.global.add.f32 %result, [%addr], %gov;
add.u32 %j, %j, %r_bdim;
bra LNB_S12;
LNB_S12D:
// Reduce sum1 in shared memory
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sum1;
bar.sync 0;
mov.u32 %half, %r_bdim;
LNB_R1:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra LNB_R1D;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra LNB_R1S;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sum1, [%saddr];
add.f32 %sum1, %sum1, %other_val;
st.shared.f32 [%saddr], %sum1;
LNB_R1S:
bar.sync 0;
bra LNB_R1;
LNB_R1D:
ld.shared.f32 %sum1, [%sbase];
// mean1 = sum1 / n
div.approx.f32 %mean1, %sum1, %n_f;
bar.sync 0;
// Reduce sum2 in shared memory
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sum2;
bar.sync 0;
mov.u32 %half, %r_bdim;
LNB_R2:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra LNB_R2D;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra LNB_R2S;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sum2, [%saddr];
add.f32 %sum2, %sum2, %other_val;
st.shared.f32 [%saddr], %sum2;
LNB_R2S:
bar.sync 0;
bra LNB_R2;
LNB_R2D:
ld.shared.f32 %sum2, [%sbase];
// mean2 = sum2 / n
div.approx.f32 %mean2, %sum2, %n_f;
bar.sync 0;
// ===== Phase 4: Compute grad_input =====
// grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
mov.u32 %j, %r_tid;
LNB_GI:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra LNB_GID;
// Reload input to recompute x_hat
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
sub.f32 %x_hat, %val, %mean;
mul.f32 %x_hat, %x_hat, %inv_std;
// Reload grad_output and weight to recompute dl_dx_hat
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %go, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %gov, [%addr];
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %w, %off;
ld.global.f32 %wv, [%addr];
mul.f32 %dl_dx_hat, %gov, %wv;
// result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
sub.f32 %result, %dl_dx_hat, %mean1;
mul.f32 %diff, %x_hat, %mean2;
sub.f32 %result, %result, %diff;
mul.f32 %result, %inv_std, %result;
// Store grad_input[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %gi, %off;
add.u64 %addr, %addr, %row_off;
st.global.f32 [%addr], %result;
add.u32 %j, %j, %r_bdim;
bra LNB_GI;
LNB_GID:
LNB_DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for the f32 RMSNorm forward kernel over a `rows x cols` matrix.
///
/// One CTA per row (`%ctaid.x`); threads stride over columns by `%ntid.x`.
/// Phase 1 reduces `sum(x^2)` through the shared `sdata` array, then
/// `inv_rms = 1/sqrt(mean(x^2) + eps)`; phase 2 writes
/// `out[j] = x[j] * inv_rms * w[j]`. Unlike LayerNorm there is no mean
/// subtraction and no bias term.
///
/// Kernel parameters: `in_ptr`/`out_ptr` (rows*cols f32), `w_ptr` (cols f32),
/// `rows`, `cols`, `eps`.
///
/// Uses approximate math (`div.approx`, `sqrt.approx`, `rcp.approx`).
///
/// NOTE(review): the tree reduction halves `%ntid.x` and `sdata` holds 256
/// floats, so this assumes a power-of-two blockDim.x <= 256 — confirm at
/// the launch site.
pub(crate) const RMSNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry rmsnorm_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u64 w_ptr,
.param .u32 rows,
.param .u32 cols,
.param .f32 eps
) {
.reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
.reg .u64 %in, %out, %w, %row_off, %off, %sbase, %saddr;
.reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %wv, %result, %other_val, %n_f;
.reg .pred %p, %lp, %rp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %w, [w_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
ld.param.f32 %eps_r, [eps];
mov.u64 %sbase, sdata;
mov.u32 %r_bid, %ctaid.x;
mov.u32 %r_bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %r_bid, %rows_reg;
@%p bra DONE;
cvt.u64.u32 %row_off, %r_bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 2;
cvt.rn.f32.u32 %n_f, %cols_reg;
// ===== Phase 1: Compute sum(x^2) =====
mov.f32 %sq_sum, 0f00000000;
mov.u32 %j, %r_tid;
SS:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra SSD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.f32 %val, [%off];
fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
add.u32 %j, %j, %r_bdim;
bra SS;
SSD:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sq_sum;
bar.sync 0;
mov.u32 %half, %r_bdim;
SR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra SRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra SRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sq_sum, [%saddr];
add.f32 %sq_sum, %sq_sum, %other_val;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sq_sum;
SRS:
bar.sync 0;
bra SR;
SRD:
ld.shared.f32 %sq_sum, [%sbase];
div.approx.f32 %sq_sum, %sq_sum, %n_f;
add.f32 %sq_sum, %sq_sum, %eps_r;
sqrt.approx.f32 %inv_rms, %sq_sum;
rcp.approx.f32 %inv_rms, %inv_rms;
bar.sync 0;
// ===== Phase 2: Normalize and scale =====
// out[j] = x[j] * inv_rms * weight[j]
mov.u32 %j, %r_tid;
NM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra NMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.f32 %val, [%off];
mul.f32 %result, %val, %inv_rms;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %w, %off;
ld.global.f32 %wv, [%off];
mul.f32 %result, %result, %wv;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
add.u64 %off, %off, %row_off;
st.global.f32 [%off], %result;
add.u32 %j, %j, %r_bdim;
bra NM;
NMD:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for the f32 RMSNorm backward kernel.
///
/// One CTA per row (`%ctaid.x`); threads stride over columns by `%ntid.x`.
/// Phases:
/// 1. recompute `inv_rms = 1/sqrt(mean(x^2) + eps)` (shared-memory tree
///    reduction) and `inv_rms3 = inv_rms^3`;
/// 2. reduce `dot = sum(go[j] * x[j] * w[j])`; simultaneously accumulates
///    `grad_weight[j] += go * x * inv_rms` with `atom.global.add.f32`
///    (rows race on the same `j`, hence the atomic);
/// 3. `grad_input[j] = inv_rms * w[j] * go[j] - x[j] * (dot * inv_rms3 / n)`.
///
/// Kernel parameters: `in_ptr`, `grad_out_ptr`, `w_ptr` (inputs),
/// `grad_in_ptr`, `grad_w_ptr` (outputs), `rows`, `cols`, `eps`.
///
/// NOTE(review): the atomic accumulation assumes `grad_w_ptr` is
/// zero-initialized by the caller — confirm at the call site. The tree
/// reduction assumes a power-of-two blockDim.x <= 256 (`sdata[256]`).
/// Uses approximate div/sqrt/rcp, matching the forward kernel.
pub(crate) const RMSNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry rmsnorm_backward_kernel(
.param .u64 in_ptr,
.param .u64 grad_out_ptr,
.param .u64 w_ptr,
.param .u64 grad_in_ptr,
.param .u64 grad_w_ptr,
.param .u32 rows,
.param .u32 cols,
.param .f32 eps
) {
.reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
.reg .u64 %in, %go, %w, %gi, %gw, %row_off, %off, %sbase, %saddr, %addr;
.reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %inv_rms3, %wv, %gov;
.reg .f32 %dot, %other_val, %n_f, %coeff, %result, %tmp;
.reg .pred %p, %lp, %rp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %go, [grad_out_ptr];
ld.param.u64 %w, [w_ptr];
ld.param.u64 %gi, [grad_in_ptr];
ld.param.u64 %gw, [grad_w_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
ld.param.f32 %eps_r, [eps];
mov.u64 %sbase, sdata;
mov.u32 %r_bid, %ctaid.x;
mov.u32 %r_bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %r_bid, %rows_reg;
@%p bra RNB_DONE;
// row_off = bid * cols * 4 (byte offset for this row)
cvt.u64.u32 %row_off, %r_bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 2;
cvt.rn.f32.u32 %n_f, %cols_reg;
// ===== Phase 1: Compute sum(x^2) -> inv_rms =====
mov.f32 %sq_sum, 0f00000000;
mov.u32 %j, %r_tid;
RNB_SS:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra RNB_SSD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
add.u32 %j, %j, %r_bdim;
bra RNB_SS;
RNB_SSD:
// Shared memory reduce for sum(x^2)
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sq_sum;
bar.sync 0;
mov.u32 %half, %r_bdim;
RNB_SR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra RNB_SRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra RNB_SRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sq_sum, [%saddr];
add.f32 %sq_sum, %sq_sum, %other_val;
st.shared.f32 [%saddr], %sq_sum;
RNB_SRS:
bar.sync 0;
bra RNB_SR;
RNB_SRD:
ld.shared.f32 %sq_sum, [%sbase];
div.approx.f32 %sq_sum, %sq_sum, %n_f;
add.f32 %sq_sum, %sq_sum, %eps_r;
sqrt.approx.f32 %inv_rms, %sq_sum;
rcp.approx.f32 %inv_rms, %inv_rms;
// inv_rms3 = inv_rms^3 = inv_rms * inv_rms * inv_rms
mul.f32 %inv_rms3, %inv_rms, %inv_rms;
mul.f32 %inv_rms3, %inv_rms3, %inv_rms;
bar.sync 0;
// ===== Phase 2: Compute dot = sum(go[j] * x[j] * w[j]) =====
// Also accumulate grad_weight via atomicAdd
mov.f32 %dot, 0f00000000;
mov.u32 %j, %r_tid;
RNB_DOT:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra RNB_DOTD;
// Load input[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
// Load grad_output[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %go, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %gov, [%addr];
// Load weight[j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %w, %off;
ld.global.f32 %wv, [%addr];
// dot += go * x * w
mul.f32 %tmp, %gov, %val;
fma.rn.f32 %dot, %tmp, %wv, %dot;
// atomicAdd grad_weight[j] += go * x * inv_rms
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %gw, %off;
mul.f32 %result, %gov, %val;
mul.f32 %result, %result, %inv_rms;
atom.global.add.f32 %result, [%addr], %result;
add.u32 %j, %j, %r_bdim;
bra RNB_DOT;
RNB_DOTD:
// Reduce dot in shared memory
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %dot;
bar.sync 0;
mov.u32 %half, %r_bdim;
RNB_DR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra RNB_DRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra RNB_DRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %dot, [%saddr];
add.f32 %dot, %dot, %other_val;
st.shared.f32 [%saddr], %dot;
RNB_DRS:
bar.sync 0;
bra RNB_DR;
RNB_DRD:
ld.shared.f32 %dot, [%sbase];
// coeff = dot * inv_rms3 / n
mul.f32 %coeff, %dot, %inv_rms3;
div.approx.f32 %coeff, %coeff, %n_f;
bar.sync 0;
// ===== Phase 3: Compute grad_input =====
// grad_input[j] = inv_rms * w[j] * go[j] - x[j] * coeff
mov.u32 %j, %r_tid;
RNB_GI:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra RNB_GID;
// Reload input
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
// Reload grad_output and weight
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %go, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %gov, [%addr];
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %w, %off;
ld.global.f32 %wv, [%addr];
// result = inv_rms * w * go - x * coeff
mul.f32 %result, %inv_rms, %wv;
mul.f32 %result, %result, %gov;
mul.f32 %tmp, %val, %coeff;
sub.f32 %result, %result, %tmp;
// Store grad_input[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %gi, %off;
add.u64 %addr, %addr, %row_off;
st.global.f32 [%addr], %result;
add.u32 %j, %j, %r_bdim;
bra RNB_GI;
RNB_GID:
RNB_DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for the f32 BatchNorm forward kernel over NCHW data.
///
/// One CTA per channel (`%ctaid.x`); threads stride over the `total_per_ch`
/// = N * spatial elements of that channel. In training mode it computes the
/// batch mean/variance with a shared-memory tree reduction, writes them to
/// `save_mean_ptr`/`save_invstd_ptr` (for the backward pass), updates the
/// running statistics with `momentum`, then normalizes:
/// `y = gamma * (x - mean) * invstd + beta`. In eval mode (`training == 0`)
/// it normalizes with the running statistics instead.
///
/// Kernel parameters: `input_ptr`/`output_ptr` (N*C*spatial f32),
/// `weight_ptr`/`bias_ptr`/`rmean_ptr`/`rvar_ptr`/`save_mean_ptr`/
/// `save_invstd_ptr` (C f32 each), `channels`, `spatial` (H*W), `eps`,
/// `momentum`, `total_per_ch` (N*spatial), `training` (0/1 flag).
///
/// Fixes relative to the earlier revision (which was an acknowledged
/// placeholder): the pass-1 grid-stride loop used a clobbered index as its
/// counter, the variance erroneously came out as E[x^2] instead of
/// E[x^2] - E[x]^2, the normalize pass and running-stat update were missing,
/// and the `training` flag was ignored. Shared-memory accesses now use the
/// base-register addressing pattern used by the other kernels in this file.
///
/// NOTE(review): the tree reduction assumes a power-of-two blockDim.x <= 256
/// (`smem_*[256]`) — confirm at the launch site. Running variance is updated
/// with the biased batch variance; PyTorch uses the unbiased estimate
/// (var * n/(n-1)) — confirm against the CPU implementation.
pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
// Shared memory for block reduction
.shared .align 4 .f32 smem_sum[256];
.shared .align 4 .f32 smem_sq[256];
.visible .entry batchnorm_forward_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u64 weight_ptr,
.param .u64 bias_ptr,
.param .u64 rmean_ptr,
.param .u64 rvar_ptr,
.param .u64 save_mean_ptr,
.param .u64 save_invstd_ptr,
.param .u32 channels,
.param .u32 spatial,
.param .f32 eps,
.param .f32 momentum,
.param .u32 total_per_ch,
.param .u32 training
) {
.reg .u32 %tid, %bid, %bdim, %ch, %n_ch, %sp, %tpc, %idx, %train;
.reg .u32 %i, %lin, %half;
.reg .u64 %in, %out, %w, %b, %rm, %rv, %sm, %si, %off64, %tmp64;
.reg .u64 %ssum, %ssq, %saddr;
.reg .f32 %sum, %sqsum, %val, %mean, %var, %invstd;
.reg .f32 %gamma, %beta, %eps_reg, %mom, %other, %omm;
.reg .f32 %n_f, %one, %normalized;
.reg .pred %p, %ptrain, %ptid0;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u64 %w, [weight_ptr];
ld.param.u64 %b, [bias_ptr];
ld.param.u64 %rm, [rmean_ptr];
ld.param.u64 %rv, [rvar_ptr];
ld.param.u64 %sm, [save_mean_ptr];
ld.param.u64 %si, [save_invstd_ptr];
ld.param.u32 %n_ch, [channels];
ld.param.u32 %sp, [spatial];
ld.param.f32 %eps_reg, [eps];
ld.param.f32 %mom, [momentum];
ld.param.u32 %tpc, [total_per_ch];
ld.param.u32 %train, [training];
mov.u32 %bid, %ctaid.x;
mov.u32 %tid, %tid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %ch, %bid;
mov.f32 %one, 0f3F800000;
mov.u64 %ssum, smem_sum;
mov.u64 %ssq, smem_sq;
setp.ge.u32 %p, %ch, %n_ch;
@%p bra END;
setp.eq.u32 %ptrain, %train, 0;
@%ptrain bra EVAL_STATS;
// ---- Pass 1 (training): per-thread partial sum and sum-of-squares ----
mov.f32 %sum, 0f00000000;
mov.f32 %sqsum, 0f00000000;
mov.u32 %i, %tid;
PASS1_LOOP:
setp.ge.u32 %p, %i, %tpc;
@%p bra PASS1_DONE;
// linear = (i / spatial) * channels * spatial + ch * spatial + (i % spatial)
div.u32 %lin, %i, %sp;
mul.lo.u32 %lin, %lin, %n_ch;
add.u32 %lin, %lin, %ch;
mul.lo.u32 %lin, %lin, %sp;
rem.u32 %idx, %i, %sp;
add.u32 %lin, %lin, %idx;
cvt.u64.u32 %off64, %lin;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %in, %off64;
ld.global.f32 %val, [%tmp64];
add.f32 %sum, %sum, %val;
fma.rn.f32 %sqsum, %val, %val, %sqsum;
add.u32 %i, %i, %bdim;
bra PASS1_LOOP;
PASS1_DONE:
// Block tree reduction of sum and sqsum in shared memory
cvt.u64.u32 %off64, %tid;
shl.b64 %off64, %off64, 2;
add.u64 %saddr, %ssum, %off64;
st.shared.f32 [%saddr], %sum;
add.u64 %saddr, %ssq, %off64;
st.shared.f32 [%saddr], %sqsum;
bar.sync 0;
mov.u32 %half, %bdim;
REDUCE_LOOP:
shr.u32 %half, %half, 1;
setp.eq.u32 %p, %half, 0;
@%p bra REDUCE_DONE;
setp.ge.u32 %p, %tid, %half;
@%p bra REDUCE_SKIP;
add.u32 %idx, %tid, %half;
cvt.u64.u32 %off64, %idx;
shl.b64 %off64, %off64, 2;
cvt.u64.u32 %tmp64, %tid;
shl.b64 %tmp64, %tmp64, 2;
add.u64 %saddr, %ssum, %off64;
ld.shared.f32 %other, [%saddr];
add.u64 %saddr, %ssum, %tmp64;
ld.shared.f32 %sum, [%saddr];
add.f32 %sum, %sum, %other;
st.shared.f32 [%saddr], %sum;
add.u64 %saddr, %ssq, %off64;
ld.shared.f32 %other, [%saddr];
add.u64 %saddr, %ssq, %tmp64;
ld.shared.f32 %sqsum, [%saddr];
add.f32 %sqsum, %sqsum, %other;
st.shared.f32 [%saddr], %sqsum;
REDUCE_SKIP:
bar.sync 0;
bra REDUCE_LOOP;
REDUCE_DONE:
// Thread 0 computes mean/invstd, saves batch stats, updates running stats
setp.ne.u32 %ptid0, %tid, 0;
@%ptid0 bra TRAIN_SYNC;
ld.shared.f32 %sum, [%ssum];
ld.shared.f32 %sqsum, [%ssq];
cvt.rn.f32.u32 %n_f, %tpc;
div.rn.f32 %mean, %sum, %n_f;
// var = E[x^2] - E[x]^2  (biased)
div.rn.f32 %var, %sqsum, %n_f;
neg.f32 %other, %mean;
fma.rn.f32 %var, %other, %mean, %var;
// invstd = 1 / sqrt(var + eps)
add.f32 %other, %var, %eps_reg;
sqrt.rn.f32 %other, %other;
div.rn.f32 %invstd, %one, %other;
// Save batch statistics for the backward pass
cvt.u64.u32 %off64, %ch;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %sm, %off64;
st.global.f32 [%tmp64], %mean;
add.u64 %tmp64, %si, %off64;
st.global.f32 [%tmp64], %invstd;
// running = (1 - momentum) * running + momentum * batch_stat
sub.f32 %omm, %one, %mom;
add.u64 %tmp64, %rm, %off64;
ld.global.f32 %other, [%tmp64];
mul.f32 %other, %other, %omm;
fma.rn.f32 %other, %mom, %mean, %other;
st.global.f32 [%tmp64], %other;
add.u64 %tmp64, %rv, %off64;
ld.global.f32 %other, [%tmp64];
mul.f32 %other, %other, %omm;
fma.rn.f32 %other, %mom, %var, %other;
st.global.f32 [%tmp64], %other;
// Publish mean/invstd to the rest of the block
st.shared.f32 [%ssum], %mean;
st.shared.f32 [%ssq], %invstd;
TRAIN_SYNC:
bra WAIT_STATS;
EVAL_STATS:
// Inference: thread 0 derives mean/invstd from the running statistics
setp.ne.u32 %ptid0, %tid, 0;
@%ptid0 bra WAIT_STATS;
cvt.u64.u32 %off64, %ch;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %rm, %off64;
ld.global.f32 %mean, [%tmp64];
add.u64 %tmp64, %rv, %off64;
ld.global.f32 %var, [%tmp64];
add.f32 %other, %var, %eps_reg;
sqrt.rn.f32 %other, %other;
div.rn.f32 %invstd, %one, %other;
st.shared.f32 [%ssum], %mean;
st.shared.f32 [%ssq], %invstd;
WAIT_STATS:
bar.sync 0;
// All threads read mean and invstd from shared
ld.shared.f32 %mean, [%ssum];
ld.shared.f32 %invstd, [%ssq];
// Load weight and bias for this channel
cvt.u64.u32 %off64, %ch;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %w, %off64;
ld.global.f32 %gamma, [%tmp64];
add.u64 %tmp64, %b, %off64;
ld.global.f32 %beta, [%tmp64];
// ---- Pass 2: y = gamma * (x - mean) * invstd + beta ----
mov.u32 %i, %tid;
PASS2_LOOP:
setp.ge.u32 %p, %i, %tpc;
@%p bra END;
div.u32 %lin, %i, %sp;
mul.lo.u32 %lin, %lin, %n_ch;
add.u32 %lin, %lin, %ch;
mul.lo.u32 %lin, %lin, %sp;
rem.u32 %idx, %i, %sp;
add.u32 %lin, %lin, %idx;
cvt.u64.u32 %off64, %lin;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %in, %off64;
ld.global.f32 %val, [%tmp64];
sub.f32 %normalized, %val, %mean;
mul.f32 %normalized, %normalized, %invstd;
fma.rn.f32 %val, %gamma, %normalized, %beta;
add.u64 %tmp64, %out, %off64;
st.global.f32 [%tmp64], %val;
add.u32 %i, %i, %bdim;
bra PASS2_LOOP;
END:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for the f32 MaxPool2d forward kernel over NCHW data.
///
/// Grid-stride loop: each thread handles output elements
/// `idx = ((b * C + c) * H_out + oh) * W_out + ow`, scans its `kh x kw`
/// window (with stride `sh`/`sw` and padding `ph`/`pw`), and writes the
/// maximum. Out-of-bounds (padded) taps are skipped via the unsigned
/// wrap-around trick (`ih`/`iw` underflow to large values and fail the
/// `< h_in`/`< w_in` check), so padding never contributes; the accumulator
/// starts at -inf.
///
/// Kernel parameters: `input_ptr`/`output_ptr`, `batch`, `channels`,
/// `h_in`, `w_in`, `h_out`, `w_out`, `kh`, `kw`, `sh`, `sw`, `ph`, `pw`,
/// `total` (= batch*channels*h_out*w_out).
///
/// Fix relative to the earlier revision: a dead `div.u32 %b_idx, %rem,
/// %ch_reg` executed on every grid-stride iteration (its result was always
/// overwritten by the right-to-left decomposition below) has been removed,
/// along with the unused `%p_gt` predicate declaration.
pub(crate) const MAXPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry maxpool2d_forward_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 batch,
.param .u32 channels,
.param .u32 h_in,
.param .u32 w_in,
.param .u32 h_out,
.param .u32 w_out,
.param .u32 kh,
.param .u32 kw,
.param .u32 sh,
.param .u32 sw,
.param .u32 ph,
.param .u32 pw,
.param .u32 total
) {
.reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
.reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
.reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
.reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
.reg .u32 %batch_reg, %ch_reg;
.reg .u64 %in, %out, %off64, %tmp64;
.reg .f32 %max_val, %cur_val, %neg_inf;
.reg .pred %p, %p_bounds;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %batch_reg, [batch];
ld.param.u32 %ch_reg, [channels];
ld.param.u32 %h_in_reg, [h_in];
ld.param.u32 %w_in_reg, [w_in];
ld.param.u32 %h_out_reg, [h_out];
ld.param.u32 %w_out_reg, [w_out];
ld.param.u32 %kh_reg, [kh];
ld.param.u32 %kw_reg, [kw];
ld.param.u32 %sh_reg, [sh];
ld.param.u32 %sw_reg, [sw];
ld.param.u32 %ph_reg, [ph];
ld.param.u32 %pw_reg, [pw];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tid, %tid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %tid;
mul.lo.u32 %stride, %bdim, %gdim;
// -inf for max initialization
mov.f32 %neg_inf, 0fFF800000;
LOOP:
setp.ge.u32 %p, %idx, %total_reg;
@%p bra END;
// Decompose from the right: idx = ((b * C + c) * H_out + oh) * W_out + ow
mov.u32 %rem, %idx;
rem.u32 %ow, %rem, %w_out_reg;
div.u32 %rem, %rem, %w_out_reg;
rem.u32 %oh, %rem, %h_out_reg;
div.u32 %rem, %rem, %h_out_reg;
rem.u32 %c_idx, %rem, %ch_reg;
div.u32 %b_idx, %rem, %ch_reg;
mov.f32 %max_val, %neg_inf;
// Slide the kernel window
mov.u32 %i, 0;
KH_LOOP:
setp.ge.u32 %p, %i, %kh_reg;
@%p bra KH_DONE;
mov.u32 %j, 0;
KW_LOOP:
setp.ge.u32 %p, %j, %kw_reg;
@%p bra KW_DONE;
// ih = oh * sh + i - ph, iw = ow * sw + j - pw
mad.lo.u32 %ih, %oh, %sh_reg, %i;
sub.u32 %ih, %ih, %ph_reg;
mad.lo.u32 %iw, %ow, %sw_reg, %j;
sub.u32 %iw, %iw, %pw_reg;
// Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
// Since unsigned, just check < h_in and < w_in
setp.ge.u32 %p_bounds, %ih, %h_in_reg;
@%p_bounds bra KW_NEXT;
setp.ge.u32 %p_bounds, %iw, %w_in_reg;
@%p_bounds bra KW_NEXT;
// input_offset = b * C * H * W + c * H * W + ih * W + iw
mul.lo.u32 %tmp, %b_idx, %ch_reg;
add.u32 %tmp, %tmp, %c_idx;
mul.lo.u32 %tmp, %tmp, %h_in_reg;
add.u32 %tmp, %tmp, %ih;
mul.lo.u32 %tmp, %tmp, %w_in_reg;
add.u32 %tmp, %tmp, %iw;
cvt.u64.u32 %off64, %tmp;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %in, %off64;
ld.global.f32 %cur_val, [%tmp64];
max.f32 %max_val, %max_val, %cur_val;
KW_NEXT:
add.u32 %j, %j, 1;
bra KW_LOOP;
KW_DONE:
add.u32 %i, %i, 1;
bra KH_LOOP;
KH_DONE:
// Store output
cvt.u64.u32 %off64, %idx;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %out, %off64;
st.global.f32 [%tmp64], %max_val;
add.u32 %idx, %idx, %stride;
bra LOOP;
END:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for the f32 AvgPool2d forward kernel over NCHW data.
///
/// Grid-stride loop mirroring the MaxPool2d kernel: each thread decomposes
/// its output index as `idx = ((b * C + c) * H_out + oh) * W_out + ow`,
/// sums the in-bounds taps of its `kh x kw` window (padded taps are skipped
/// via the unsigned wrap-around bounds check), counts them, and writes
/// `sum / count` — i.e. `count_include_pad = false` behavior.
///
/// Kernel parameters: `input_ptr`/`output_ptr`, `batch`, `channels`,
/// `h_in`, `w_in`, `h_out`, `w_out`, `kh`, `kw`, `sh`, `sw`, `ph`, `pw`,
/// `total` (= batch*channels*h_out*w_out).
///
/// NOTE(review): if a window lies entirely in padding, `count` is 0 and the
/// division produces NaN; presumably the host-side shape computation
/// guarantees every window overlaps the input — confirm at the call site.
pub(crate) const AVGPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry avgpool2d_forward_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 batch,
.param .u32 channels,
.param .u32 h_in,
.param .u32 w_in,
.param .u32 h_out,
.param .u32 w_out,
.param .u32 kh,
.param .u32 kw,
.param .u32 sh,
.param .u32 sw,
.param .u32 ph,
.param .u32 pw,
.param .u32 total
) {
.reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
.reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
.reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
.reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
.reg .u32 %batch_reg, %ch_reg;
.reg .u64 %in, %out, %off64, %tmp64;
.reg .f32 %sum_val, %cur_val, %count_f, %avg;
.reg .pred %p, %p_bounds;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %batch_reg, [batch];
ld.param.u32 %ch_reg, [channels];
ld.param.u32 %h_in_reg, [h_in];
ld.param.u32 %w_in_reg, [w_in];
ld.param.u32 %h_out_reg, [h_out];
ld.param.u32 %w_out_reg, [w_out];
ld.param.u32 %kh_reg, [kh];
ld.param.u32 %kw_reg, [kw];
ld.param.u32 %sh_reg, [sh];
ld.param.u32 %sw_reg, [sw];
ld.param.u32 %ph_reg, [ph];
ld.param.u32 %pw_reg, [pw];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tid, %tid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %tid;
mul.lo.u32 %stride, %bdim, %gdim;
LOOP:
setp.ge.u32 %p, %idx, %total_reg;
@%p bra END;
// Decompose idx into (b, c, oh, ow) — same as MaxPool2d
mov.u32 %rem, %idx;
rem.u32 %ow, %rem, %w_out_reg;
div.u32 %rem, %rem, %w_out_reg;
rem.u32 %oh, %rem, %h_out_reg;
div.u32 %rem, %rem, %h_out_reg;
rem.u32 %c_idx, %rem, %ch_reg;
div.u32 %b_idx, %rem, %ch_reg;
mov.f32 %sum_val, 0f00000000;
mov.u32 %count, 0;
mov.u32 %i, 0;
AKH_LOOP:
setp.ge.u32 %p, %i, %kh_reg;
@%p bra AKH_DONE;
mov.u32 %j, 0;
AKW_LOOP:
setp.ge.u32 %p, %j, %kw_reg;
@%p bra AKW_DONE;
mad.lo.u32 %ih, %oh, %sh_reg, %i;
sub.u32 %ih, %ih, %ph_reg;
mad.lo.u32 %iw, %ow, %sw_reg, %j;
sub.u32 %iw, %iw, %pw_reg;
setp.ge.u32 %p_bounds, %ih, %h_in_reg;
@%p_bounds bra AKW_NEXT;
setp.ge.u32 %p_bounds, %iw, %w_in_reg;
@%p_bounds bra AKW_NEXT;
mul.lo.u32 %tmp, %b_idx, %ch_reg;
add.u32 %tmp, %tmp, %c_idx;
mul.lo.u32 %tmp, %tmp, %h_in_reg;
add.u32 %tmp, %tmp, %ih;
mul.lo.u32 %tmp, %tmp, %w_in_reg;
add.u32 %tmp, %tmp, %iw;
cvt.u64.u32 %off64, %tmp;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %in, %off64;
ld.global.f32 %cur_val, [%tmp64];
add.f32 %sum_val, %sum_val, %cur_val;
add.u32 %count, %count, 1;
AKW_NEXT:
add.u32 %j, %j, 1;
bra AKW_LOOP;
AKW_DONE:
add.u32 %i, %i, 1;
bra AKH_LOOP;
AKH_DONE:
// avg = sum / count (count_include_pad = false behavior)
cvt.rn.f32.u32 %count_f, %count;
div.rn.f32 %avg, %sum_val, %count_f;
cvt.u64.u32 %off64, %idx;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %out, %off64;
st.global.f32 [%tmp64], %avg;
add.u32 %idx, %idx, %stride;
bra LOOP;
END:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for row-wise f32 softmax. One CTA processes one row (`ctaid.x`),
/// threads stride across the row's `cols` elements.
///
/// Entry `softmax_kernel(input_ptr, output_ptr, rows, cols)`. Three passes:
/// 1. block-wide max of the row (numerical stabilization),
/// 2. exp(x - max) via `ex2.approx` with the log2(e) factor `0f3FB8AA3B`,
///    written to `out` and accumulated into a block-wide sum,
/// 3. scale `out` by the reciprocal of the sum (`rcp.approx`).
/// Reductions use the 256-slot shared buffer `sdata` with a halving tree, so
/// the launch must use blockDim.x <= 256; the halving tree also drops lanes
/// for non-power-of-two block sizes — assumed power-of-two, TODO confirm at
/// the launch site.
///
/// Fix: two lines previously lacked the trailing `\n\` escape/continuation
/// and relied on the raw newline of the multi-line literal instead. The
/// string value was identical, but the literal now uses one escaping style
/// throughout (mixed styles here are how subtle literal bugs hide).
pub(crate) const SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
.reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
.reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
.reg .pred %p, %loop_p;\n\
.reg .u32 %half, %other_tid;\n\
.reg .f32 %other_val;\n\
.reg .pred %reduce_p;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %row_off, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off, %row_off, %off;\n\
shl.b64 %row_off, %row_off, 2;\n\
\n\
mov.f32 %max_val, 0fFF800000;\n\
mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra FIND_MAX_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f32 %val, [%off];\n\
max.f32 %max_val, %max_val, %val;\n\
add.u32 %j, %j, %bdim;\n\
bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %max_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra MAX_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra MAX_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %max_val, [%saddr];\n\
max.f32 %max_val, %max_val, %other_val;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
ld.shared.f32 %max_val, [sdata];\n\
bar.sync 0;\n\
\n\
mov.f32 %sum_val, 0f00000000;\n\
mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra SUM_EXP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f32 %val, [%off];\n\
sub.f32 %val, %val, %max_val;\n\
mul.f32 %val, %val, 0f3FB8AA3B;\n\
ex2.approx.f32 %exp_val, %val;\n\
add.f32 %sum_val, %sum_val, %exp_val;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %out, %off;\n\
add.u64 %off, %off, %row_off;\n\
st.global.f32 [%off], %exp_val;\n\
add.u32 %j, %j, %bdim;\n\
bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %sum_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra SUM_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra SUM_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %sum_val, [%saddr];\n\
add.f32 %sum_val, %sum_val, %other_val;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
ld.shared.f32 %sum_val, [sdata];\n\
bar.sync 0;\n\
\n\
rcp.approx.f32 %sum_val, %sum_val;\n\
mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra NORMALIZE_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %out, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f32 %val, [%off];\n\
mul.f32 %result, %val, %sum_val;\n\
st.global.f32 [%off], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
/// PTX for row-wise f64 softmax. One CTA per row, threads stride over `cols`;
/// same three-pass structure as the f32 kernel (block max, exp(x - max) and
/// block sum, then normalize), with an inline double-precision exp:
/// Cody–Waite range reduction (x = n*ln2 + r), a Taylor polynomial in r, and
/// 2^n scaling by constructing the IEEE-754 bit pattern (n + 1023) << 52.
/// Shared buffer `sdata` has 256 slots and the reduction tree halves
/// blockDim, so blockDim.x must be <= 256 and is assumed power-of-two —
/// TODO confirm at the launch site.
///
/// Fixes: the polynomial previously used `0d3F811111111111111` — 17 hex
/// digits, which ptxas rejects (`0d` literals take exactly 16) — and the
/// Taylor chain skipped the 1/4! coefficient (went 1/6!, then the malformed
/// constant, then 1/3!). The constant is corrected to 1/5! and the missing
/// 1/4! term (0d3FA5555555555555) inserted, so exp(r) is the proper series
/// 1 + r + r^2/2! + ... + r^12/12!.
pub(crate) const SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry softmax_f64_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
.reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
.reg .f64 %val, %max_val, %sum_val, %exp_val, %result, %one;\n\
.reg .pred %p, %loop_p;\n\
.reg .u32 %half, %other_tid;\n\
.reg .f64 %other_val;\n\
.reg .pred %reduce_p;\n\
.reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
.reg .s32 %e_ni;\n\
.reg .s64 %e_ni64, %e_bits;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata;\n\
mov.f64 %one, 0d3FF0000000000000;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %row_off, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off, %row_off, %off;\n\
shl.b64 %row_off, %row_off, 3;\n\
\n\
mov.f64 %max_val, 0dFFF0000000000000;\n\
mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra FIND_MAX_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f64 %val, [%off];\n\
max.f64 %max_val, %max_val, %val;\n\
add.u32 %j, %j, %bdim;\n\
bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f64 [%saddr], %max_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra MAX_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra MAX_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %max_val, [%saddr];\n\
max.f64 %max_val, %max_val, %other_val;\n\
st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
ld.shared.f64 %max_val, [sdata];\n\
bar.sync 0;\n\
\n\
mov.f64 %sum_val, 0d0000000000000000;\n\
mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra SUM_EXP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f64 %val, [%off];\n\
sub.f64 %val, %val, %max_val;\n\
mov.f64 %e_one, 0d3FF0000000000000;\n\
mov.f64 %e_half, 0d3FE0000000000000;\n\
mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
cvt.rni.f64.f64 %e_nf, %e_nf;\n\
cvt.rni.s32.f64 %e_ni, %e_nf;\n\
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
cvt.s64.s32 %e_ni64, %e_ni;\n\
add.s64 %e_ni64, %e_ni64, 1023;\n\
shl.b64 %e_bits, %e_ni64, 52;\n\
mov.b64 %e_nf, %e_bits;\n\
mul.f64 %exp_val, %exp_val, %e_nf;\n\
add.f64 %sum_val, %sum_val, %exp_val;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %off, %out, %off;\n\
add.u64 %off, %off, %row_off;\n\
st.global.f64 [%off], %exp_val;\n\
add.u32 %j, %j, %bdim;\n\
bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f64 [%saddr], %sum_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra SUM_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra SUM_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %sum_val, [%saddr];\n\
add.f64 %sum_val, %sum_val, %other_val;\n\
st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
ld.shared.f64 %sum_val, [sdata];\n\
bar.sync 0;\n\
\n\
div.rn.f64 %sum_val, %one, %sum_val;\n\
mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra NORMALIZE_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %off, %out, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f64 %val, [%off];\n\
mul.f64 %result, %val, %sum_val;\n\
st.global.f64 [%off], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
/// PTX for elementwise f32 dropout (training mode).
///
/// Entry `dropout_kernel(input_ptr, output_ptr, n, threshold, scale, seed)`.
/// One thread per element, no grid-stride loop — the launch must cover all
/// `n` elements. A per-element hash RNG (index * 2654435761, xor with `seed`,
/// then a 13/17/5 xorshift) yields a u32; elements whose hash is below
/// `threshold` are zeroed, the rest are multiplied by `scale` (inverted
/// dropout), both via predicated instructions. Deterministic for a fixed
/// seed. Presumably the host computes threshold ≈ p * 2^32 and
/// scale = 1/(1-p) — verify at the call site.
pub(crate) const DROPOUT_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry dropout_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 n,\n\
.param .u32 threshold,\n\
.param .f32 scale,\n\
.param .u32 seed\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;\n\
.reg .u64 %in, %out, %off;\n\
.reg .f32 %val, %scale_reg, %zero;\n\
.reg .pred %p, %drop_p;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %n_reg, [n];\n\
ld.param.u32 %thresh, [threshold];\n\
ld.param.f32 %scale_reg, [scale];\n\
ld.param.u32 %seed_reg, [seed];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %n_reg;\n\
@%p bra DONE;\n\
\n\
mul.lo.u32 %rng, %r_tid, 2654435761;\n\
xor.b32 %rng, %rng, %seed_reg;\n\
shl.b32 %tmp, %rng, 13;\n\
xor.b32 %rng, %rng, %tmp;\n\
shr.b32 %tmp, %rng, 17;\n\
xor.b32 %rng, %rng, %tmp;\n\
shl.b32 %tmp, %rng, 5;\n\
xor.b32 %rng, %rng, %tmp;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %in, %in, %off;\n\
add.u64 %out, %out, %off;\n\
ld.global.f32 %val, [%in];\n\
\n\
setp.lo.u32 %drop_p, %rng, %thresh;\n\
mov.f32 %zero, 0f00000000;\n\
@%drop_p mov.f32 %val, %zero;\n\
@!%drop_p mul.f32 %val, %val, %scale_reg;\n\
\n\
st.global.f32 [%out], %val;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
/// PTX for broadcast elementwise add (f32) over arbitrary-rank operands.
///
/// Entry `broadcast_add_kernel(a_ptr, b_ptr, out_ptr, a_strides_ptr,
/// b_strides_ptr, out_shape_ptr, n, ndim)`. One thread per output element
/// (no grid stride — launch must cover `n`). Each thread decomposes its flat
/// output index into N-d coordinates against `out_shape`, iterating from the
/// innermost dimension outward, and accumulates each operand's element index
/// from its per-dimension strides (u32 element counts, not bytes); a stride
/// of 0 realizes broadcasting along that dimension. out[tid] = a[a_idx] +
/// b[b_idx].
pub(crate) const BROADCAST_ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry broadcast_add_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
// Global thread index.
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// Decompose flat index into N-d coordinates and compute A/B indices.
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
// Byte offset for dimension d: d * 4.
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
// Load out_shape[d].
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
// Load a_strides[d] and b_strides[d].
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
// coord = remaining % shape_d; remaining /= shape_d.
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
// a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
// Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 2;
add.u64 %off_a, %a, %off_a;
ld.global.f32 %va, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 2;
add.u64 %off_b, %b, %off_b;
ld.global.f32 %vb, [%off_b];
// Operation: add.
add.f32 %vr, %va, %vb;
// Store to out[tid].
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 2;
add.u64 %off_out, %out, %off_out;
st.global.f32 [%off_out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for broadcast elementwise subtract (f32): out[tid] = a[a_idx] -
/// b[b_idx].
///
/// Entry `broadcast_sub_kernel(a_ptr, b_ptr, out_ptr, a_strides_ptr,
/// b_strides_ptr, out_shape_ptr, n, ndim)`. One thread per output element
/// (no grid stride — launch must cover `n`). The flat output index is
/// decomposed into N-d coordinates against `out_shape` (innermost dimension
/// first); each operand's element index is accumulated from its u32 element
/// strides, with stride 0 realizing broadcasting along a dimension.
pub(crate) const BROADCAST_SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry broadcast_sub_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 2;
add.u64 %off_a, %a, %off_a;
ld.global.f32 %va, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 2;
add.u64 %off_b, %b, %off_b;
ld.global.f32 %vb, [%off_b];
sub.f32 %vr, %va, %vb;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 2;
add.u64 %off_out, %out, %off_out;
st.global.f32 [%off_out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for broadcast elementwise multiply (f32): out[tid] = a[a_idx] *
/// b[b_idx].
///
/// Entry `broadcast_mul_kernel(a_ptr, b_ptr, out_ptr, a_strides_ptr,
/// b_strides_ptr, out_shape_ptr, n, ndim)`. One thread per output element
/// (no grid stride — launch must cover `n`). The flat output index is
/// decomposed into N-d coordinates against `out_shape` (innermost dimension
/// first); each operand's element index is accumulated from its u32 element
/// strides, with stride 0 realizing broadcasting along a dimension.
pub(crate) const BROADCAST_MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry broadcast_mul_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 2;
add.u64 %off_a, %a, %off_a;
ld.global.f32 %va, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 2;
add.u64 %off_b, %b, %off_b;
ld.global.f32 %vb, [%off_b];
mul.f32 %vr, %va, %vb;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 2;
add.u64 %off_out, %out, %off_out;
st.global.f32 [%off_out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for broadcast elementwise division (f32): out[tid] = a[a_idx] /
/// b[b_idx], IEEE round-to-nearest.
///
/// Entry `broadcast_div_kernel(a_ptr, b_ptr, out_ptr, a_strides_ptr,
/// b_strides_ptr, out_shape_ptr, n, ndim)`. One thread per output element
/// (no grid stride — launch must cover `n`). The flat output index is
/// decomposed into N-d coordinates against `out_shape` (innermost dimension
/// first); each operand's element index is accumulated from its u32 element
/// strides, with stride 0 realizing broadcasting along a dimension.
///
/// Fix: the divide was a bare `div.f32`; PTX floating-point `div` requires a
/// rounding/approximation modifier, and the elementwise `div_kernel` in this
/// file uses `div.rn.f32`. Changed to `div.rn.f32` for validity and
/// consistency (the f32→f64 converter already maps `div.rn.f32` →
/// `div.rn.f64`).
pub(crate) const BROADCAST_DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry broadcast_div_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 2;
add.u64 %off_a, %a, %off_a;
ld.global.f32 %va, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 2;
add.u64 %off_b, %b, %off_b;
ld.global.f32 %vb, [%off_b];
div.rn.f32 %vr, %va, %vb;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 2;
add.u64 %off_out, %out, %off_out;
st.global.f32 [%off_out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX that copies one chunk of a split-along-an-axis out of a contiguous
/// f32 tensor.
///
/// Entry `strided_split_kernel(input_ptr, output_ptr, total_along_axis,
/// split_offset, split_size, inner_size, n)`. One thread per output element
/// (no grid stride — launch must cover `n`). Output index r maps to
///   input[outer * total_along_axis * inner_size
///         + split_offset * inner_size + within]
/// where chunk_stride = split_size * inner_size, outer = r / chunk_stride
/// and within = r % chunk_stride. `total_along_axis` is the source's full
/// extent along the split axis; `split_offset`/`split_size` select the chunk.
pub(crate) const STRIDED_SPLIT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry strided_split_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 total_along_axis,
.param .u32 split_offset,
.param .u32 split_size,
.param .u32 inner_size,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
.reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off, %tmp;
.reg .u64 %in, %out, %off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %total_ax, [total_along_axis];
ld.param.u32 %sp_off, [split_offset];
ld.param.u32 %sp_sz, [split_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// chunk_stride = split_size * inner_size
mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;
// outer_idx = r_tid / chunk_stride
div.u32 %outer_idx, %r_tid, %chunk_stride;
// within = r_tid % chunk_stride
rem.u32 %within, %r_tid, %chunk_stride;
// base_off = split_offset * inner_size
mul.lo.u32 %base_off, %sp_off, %inner_sz;
// src_idx = outer_idx * total_along_axis * inner_size + base_off + within
mul.lo.u32 %src_idx, %outer_idx, %total_ax;
mul.lo.u32 %src_idx, %src_idx, %inner_sz;
add.u32 %src_idx, %src_idx, %base_off;
add.u32 %src_idx, %src_idx, %within;
// Load from in[src_idx]
cvt.u64.u32 %off, %src_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %val, [%off];
// Store to out[r_tid]
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX that scatters one input part into a concatenation-along-an-axis
/// output (f32) — the inverse mapping of the split kernel.
///
/// Entry `strided_cat_kernel(input_ptr, output_ptr, total_along_axis,
/// cat_offset, part_size, inner_size, n)`. One thread per *input* element
/// (no grid stride — launch must cover `n`). Input index r maps to
///   output[outer * total_along_axis * inner_size
///          + cat_offset * inner_size + within]
/// where chunk_stride = part_size * inner_size, outer = r / chunk_stride and
/// within = r % chunk_stride. `total_along_axis` is the destination's full
/// extent along the cat axis; `cat_offset` is where this part begins.
pub(crate) const STRIDED_CAT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry strided_cat_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 total_along_axis,
.param .u32 cat_offset,
.param .u32 part_size,
.param .u32 inner_size,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
.reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
.reg .u64 %in, %out, %off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %total_ax, [total_along_axis];
ld.param.u32 %cat_off, [cat_offset];
ld.param.u32 %part_sz, [part_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// chunk_stride = part_size * inner_size
mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;
// outer_idx = r_tid / chunk_stride
div.u32 %outer_idx, %r_tid, %chunk_stride;
// within = r_tid % chunk_stride
rem.u32 %within, %r_tid, %chunk_stride;
// base_off = cat_offset * inner_size
mul.lo.u32 %base_off, %cat_off, %inner_sz;
// dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
add.u32 %dst_idx, %dst_idx, %base_off;
add.u32 %dst_idx, %dst_idx, %within;
// Load from in[r_tid]
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %val, [%off];
// Store to out[dst_idx]
cvt.u64.u32 %off, %dst_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX gather-copy for up to 8 dimensions: materializes a strided f32 view
/// into a contiguous output buffer.
///
/// Entry `strided_copy_kernel(input_ptr, output_ptr, src_offset_base, n,
/// os0..os7, ss0..ss7)`. One thread per output element (no grid stride —
/// launch must cover `n`). For each dimension d the kernel peels a
/// coordinate off the flat output index via coord = flat / os_d;
/// flat -= coord * os_d, then advances the source index by coord * ss_d;
/// so os_d is the contiguous sub-size of the output at dim d (product of
/// trailing extents) and ss_d is the source element stride. All 8 dims are
/// processed unconditionally — presumably the host pads unused dims with
/// os = 1, ss = 0 (an os of 0 would be an integer divide by zero); verify
/// at the call site. `src_offset_base` seeds the source index for views
/// with a storage offset.
pub(crate) const STRIDED_COPY_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry strided_copy_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 src_offset_base,
.param .u32 n,
.param .u32 os0, .param .u32 os1, .param .u32 os2, .param .u32 os3,
.param .u32 os4, .param .u32 os5, .param .u32 os6, .param .u32 os7,
.param .u32 ss0, .param .u32 ss1, .param .u32 ss2, .param .u32 ss3,
.param .u32 ss4, .param .u32 ss5, .param .u32 ss6, .param .u32 ss7
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u32 %flat, %src_idx, %coord, %tmp, %os, %ss;
.reg .u64 %in, %out, %off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %src_idx, [src_offset_base];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %flat, %r_tid;
// Dim 0
ld.param.u32 %os, [os0];
ld.param.u32 %ss, [ss0];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 1
ld.param.u32 %os, [os1];
ld.param.u32 %ss, [ss1];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 2
ld.param.u32 %os, [os2];
ld.param.u32 %ss, [ss2];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 3
ld.param.u32 %os, [os3];
ld.param.u32 %ss, [ss3];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 4
ld.param.u32 %os, [os4];
ld.param.u32 %ss, [ss4];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 5
ld.param.u32 %os, [os5];
ld.param.u32 %ss, [ss5];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 6
ld.param.u32 %os, [os6];
ld.param.u32 %ss, [ss6];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 7
ld.param.u32 %os, [os7];
ld.param.u32 %ss, [ss7];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Load from in[src_idx]
cvt.u64.u32 %off, %src_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %val, [%off];
// Store to out[r_tid]
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise f32 division: out[i] = a[i] / b[i] using IEEE
/// round-to-nearest (`div.rn.f32`).
///
/// Entry `div_kernel(a_ptr, b_ptr, out_ptr, n)`. One thread per element, no
/// grid-stride loop — the launch must cover all `n` elements. Division by
/// zero follows IEEE 754 semantics (±inf or NaN).
pub(crate) const DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry div_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
div.rn.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise f32 exponential: out[i] = exp(a[i]).
///
/// Entry `exp_kernel(a_ptr, out_ptr, n)`. One thread per element, no
/// grid-stride loop — the launch must cover all `n` elements. Uses the
/// hardware `ex2.approx` (2^x) with the identity exp(x) = 2^(x * log2(e)),
/// so results carry single-precision approximation error rather than being
/// correctly rounded.
pub(crate) const EXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry exp_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
// PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
// log2(e) = 1.4426950408889634
mul.f32 %va, %va, 0f3FB8AA3B;
ex2.approx.f32 %vr, %va;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise f64 exponential: out[i] = exp(a[i]).
///
/// Entry `exp_f64_kernel(a_ptr, out_ptr, n)`. One thread per element, no
/// grid-stride loop — the launch must cover all `n` elements. Implements exp
/// in software: Cody–Waite range reduction x = n*ln2 + r, an 11-term Taylor
/// polynomial in r (coefficients 1/2! .. 1/12!), and scaling by 2^n built
/// from the IEEE-754 bit pattern (n + 1023) << 52. NOTE(review): the 2^n
/// construction does not special-case |n| large enough to overflow the
/// exponent field (|x| beyond roughly ±709) — out-of-range inputs produce
/// garbage rather than 0/inf; confirm callers clamp if needed.
///
/// Fixes: (1) removed `mov.f64 %ln2_hi/%ln2_lo` — those registers were never
/// declared in any `.reg` statement (ptxas error) and were dead anyway, the
/// Cody–Waite constants being inlined in the fma immediates; (2) the Taylor
/// chain used the malformed 17-hex-digit literal `0d3F811111111111111`
/// (ptxas requires exactly 16) and skipped the 1/4! term — corrected to
/// 1/5! = 0d3F81111111111111 and inserted 1/4! = 0d3FA5555555555555;
/// (3) the factorial comments were off by one (the leading coefficient is
/// 1/12!, not 1/11!).
pub(crate) const EXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry exp_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %x, %vr;
.reg .f64 %log2e, %nf, %r;
.reg .f64 %p, %one, %half;
.reg .s32 %ni;
.reg .s64 %ni64, %exp_bits;
.reg .pred %p_bounds, %p_tid;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p_tid, %r_tid, %n_reg;
@%p_tid bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%a];
// Constants
mov.f64 %log2e, 0d3FF71547652B82FE; // log2(e) = 1.4426950408889634
mov.f64 %one, 0d3FF0000000000000; // 1.0
mov.f64 %half, 0d3FE0000000000000; // 0.5
// n = round(x * log2(e))
mul.f64 %nf, %x, %log2e;
cvt.rni.f64.f64 %nf, %nf; // round to nearest integer
cvt.rni.s32.f64 %ni, %nf; // integer n
// r = x - n * ln2 (Cody-Waite two-step for precision)
fma.rn.f64 %r, %nf, 0dBFE62E42FEFA3800, %x; // r = x - n*ln2_hi
fma.rn.f64 %r, %nf, 0dBD2EF35793C76730, %r; // r -= n*ln2_lo
// Horner polynomial for exp(r) - 1 - r = r^2 * (1/2! + r*(1/3! + r*(1/4! + ...)))
// p starts at 1/12!, accumulates down to 1/2!
mov.f64 %p, 0d3E21EED8EFF8D898; // 1/12! = 2.088e-9
fma.rn.f64 %p, %p, %r, 0d3E5AE64567F544E4; // 1/11! = 2.505e-8
fma.rn.f64 %p, %p, %r, 0d3E927E4FB7789F5C; // 1/10! = 2.756e-7
fma.rn.f64 %p, %p, %r, 0d3EC71DE3A556C734; // 1/9! = 2.756e-6
fma.rn.f64 %p, %p, %r, 0d3EFA01A01A01A01A; // 1/8! = 2.480e-5
fma.rn.f64 %p, %p, %r, 0d3F2A01A01A01A01A; // 1/7! = 1.984e-4
fma.rn.f64 %p, %p, %r, 0d3F56C16C16C16C17; // 1/6! = 1.389e-3
fma.rn.f64 %p, %p, %r, 0d3F81111111111111; // 1/5! = 8.333e-3
fma.rn.f64 %p, %p, %r, 0d3FA5555555555555; // 1/4! = 4.167e-2
fma.rn.f64 %p, %p, %r, 0d3FC5555555555555; // 1/3! = 1.667e-1
fma.rn.f64 %p, %p, %r, %half; // 1/2! = 5.000e-1
// exp(r) = 1 + r + r^2 * p => 1 + r*(1 + r*p)
fma.rn.f64 %p, %p, %r, %one; // p = r*p + 1
fma.rn.f64 %vr, %p, %r, %one; // vr = p*r + 1 = exp(r)
// Scale by 2^n: multiply by constructing the f64 bit pattern for 2^n.
// IEEE 754 f64: 2^n has exponent field = n + 1023, no mantissa bits.
// Bit pattern: (n + 1023) << 52.
cvt.s64.s32 %ni64, %ni;
add.s64 %ni64, %ni64, 1023;
shl.b64 %exp_bits, %ni64, 52;
mov.b64 %nf, %exp_bits; // reinterpret as f64 = 2^n
mul.f64 %vr, %vr, %nf;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise f32 natural log: out[i] = ln(a[i]).
///
/// Entry `log_kernel(a_ptr, out_ptr, n)`. One thread per element, no
/// grid-stride loop — the launch must cover all `n` elements. Uses the
/// hardware `lg2.approx` (log2) scaled by ln(2) (0f3F317218), so results are
/// approximate rather than correctly rounded. Non-positive inputs follow
/// `lg2.approx` semantics (-inf at zero, NaN for negatives — per the PTX ISA;
/// verify against the target's behavior if it matters).
pub(crate) const LOG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry log_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
// PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
// 1/log2(e) = ln(2) = 0.6931471805599453
lg2.approx.f32 %vr, %va;
mul.f32 %vr, %vr, 0f3F317218;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise natural logarithm on an f64 buffer.
///
/// Decomposes x = 2^n * m (m in [1, 2)) by bit-manipulating the IEEE 754
/// exponent field, evaluates ln(m) with an atanh-series Horner polynomial in
/// f = (m-1)/(m+1), then combines as ln(x) = n*ln(2) + ln(m).
///
/// NOTE(review): there is no special-casing of x <= 0, NaN, infinities, or
/// denormals — the bit decomposition silently yields garbage for those inputs;
/// confirm callers only pass positive normal values.
pub(crate) const LOG_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry log_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .u64 %xbits, %mantissa_bits, %bias_bits;
.reg .f64 %x, %vr, %m, %f, %f2, %s, %p;
.reg .f64 %ln2_hi, %ln2_lo, %one, %two;
.reg .s32 %exp_i;
.reg .s64 %exp64;
.reg .f64 %nf;
.reg .pred %p_tid;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p_tid, %r_tid, %n_reg;
@%p_tid bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%a];
mov.f64 %ln2_hi, 0d3FE62E42FEFA39EF; // ln(2) = 0.6931471805599453
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %two, 0d4000000000000000;
// Extract exponent: n = exponent_field - 1023
mov.b64 %xbits, %x;
shr.u64 %exp64, %xbits, 52;
and.b64 %exp64, %exp64, 2047; // 11-bit exponent field
sub.s64 %exp64, %exp64, 1023;
cvt.rn.f64.s64 %nf, %exp64; // n as f64
// Extract mantissa m: set exponent to 1023 (so m is in [1, 2))
mov.u64 %bias_bits, 0x3FF0000000000000; // exponent = 1023
and.b64 %mantissa_bits, %xbits, 0x000FFFFFFFFFFFFF; // mantissa bits
or.b64 %mantissa_bits, %mantissa_bits, %bias_bits;
mov.b64 %m, %mantissa_bits; // m in [1.0, 2.0)
// f = (m - 1) / (m + 1) — maps [1,2) to [0, 1/3)
sub.f64 %f, %m, %one;
add.f64 %s, %m, %one;
div.rn.f64 %f, %f, %s;
// ln(m) = 2*f + 2*f^3/3 + 2*f^5/5 + 2*f^7/7 + 2*f^9/9 + 2*f^11/11
// Horner: ln(m) = 2*f*(1 + f^2*(1/3 + f^2*(1/5 + f^2*(1/7 + f^2*(1/9 + f^2/11)))))
mul.f64 %f2, %f, %f;
// p = 1/11
mov.f64 %p, 0d3FB745D1745D1746;
// p = p*f2 + 1/9
fma.rn.f64 %p, %p, %f2, 0d3FC1C71C71C71C72;
// p = p*f2 + 1/7
fma.rn.f64 %p, %p, %f2, 0d3FC2492492492492;
// p = p*f2 + 1/5
fma.rn.f64 %p, %p, %f2, 0d3FC999999999999A;
// p = p*f2 + 1/3
fma.rn.f64 %p, %p, %f2, 0d3FD5555555555555;
// p = p*f2 + 1
fma.rn.f64 %p, %p, %f2, %one;
// ln(m) = 2*f*p
mul.f64 %p, %p, %f;
add.f64 %p, %p, %p; // * 2
// ln(x) = n*ln(2) + ln(m)
fma.rn.f64 %vr, %nf, %ln2_hi, %p;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise square root on an f32 buffer, using the IEEE
/// round-to-nearest `sqrt.rn.f32` instruction (correctly-rounded, unlike
/// the `.approx` variant). Parameters: input pointer, output pointer, count.
pub(crate) const SQRT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sqrt_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
sqrt.rn.f32 %vr, %va;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise x^e on an f32 buffer with a scalar f32 exponent,
/// computed as 2^(e * log2(x)) using the `lg2.approx` / `ex2.approx` units.
///
/// NOTE(review): `lg2.approx` is only meaningful for x > 0 — zero or negative
/// inputs yield -Inf/NaN rather than the 0^e / integer-exponent semantics of
/// `powf`; confirm this is acceptable at the call sites.
pub(crate) const POW_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry pow_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .f32 exponent,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %exp, %lg;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.f32 %exp, [exponent];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
// x^e = 2^(e * log2(x))
lg2.approx.f32 %lg, %va;
mul.f32 %lg, %lg, %exp;
ex2.approx.f32 %vr, %lg;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise x^e on an f64 buffer with a scalar f64 exponent,
/// computed as exp(e * ln(x)).
///
/// ln(x) uses IEEE exponent/mantissa decomposition plus an atanh-series Horner
/// polynomial; exp(z) uses Cody-Waite range reduction (z = n*ln2 + r, |r| <=
/// ln2/2) with a factorial-series Horner polynomial, then scales by 2^n built
/// directly from the IEEE bit pattern ((n + 1023) << 52).
///
/// Fixes over the previous revision: the 1/5! coefficient had 17 hex digits
/// (`0d` literals must have exactly 16, so ptxas rejected the whole module and
/// the op silently fell back to the CPU path), and the 1/3! term was missing
/// from the exp Horner chain (~1e-3 relative error had the module loaded).
///
/// NOTE(review): like the f32 version this assumes x > 0; the 2^n bit trick
/// also misbehaves when n leaves the normal exponent range — confirm callers.
pub(crate) const POW_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry pow_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .f64 exponent,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %va, %vr, %exp64, %one, %two;
// log registers
.reg .u64 %l_xbits, %l_mbits, %l_bias;
.reg .s64 %l_exp64;
.reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2, %l_lnx;
// exp registers
.reg .f64 %e_z, %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.f64 %exp64, [exponent];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %va, [%a];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %two, 0d4000000000000000;
// === ln(va) via argument reduction ===
// Decompose va = 2^n * m, m in [1,2), ln(va) = n*ln(2) + ln(m)
mov.b64 %l_xbits, %va;
shr.u64 %l_exp64, %l_xbits, 52;
and.b64 %l_exp64, %l_exp64, 2047;
sub.s64 %l_exp64, %l_exp64, 1023;
cvt.rn.f64.s64 %l_nf, %l_exp64;
mov.u64 %l_bias, 0x3FF0000000000000;
and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
or.b64 %l_mbits, %l_mbits, %l_bias;
mov.b64 %l_m, %l_mbits;
// f = (m-1)/(m+1)
sub.f64 %l_f, %l_m, %one;
add.f64 %l_s, %l_m, %one;
div.rn.f64 %l_f, %l_f, %l_s;
mul.f64 %l_f2, %l_f, %l_f;
// Horner: p = 1/11 + f2*(1/9 + f2*(1/7 + f2*(1/5 + f2*(1/3 + f2*1))))
mov.f64 %l_p, 0d3FB745D1745D1746;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
fma.rn.f64 %l_p, %l_p, %l_f2, %one;
// ln(m) = 2*f*p
mul.f64 %l_p, %l_p, %l_f;
add.f64 %l_p, %l_p, %l_p;
// ln(x) = n*ln(2) + ln(m)
mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
fma.rn.f64 %l_lnx, %l_nf, %l_ln2, %l_p;
// === exp(exponent * ln(x)) ===
mul.f64 %e_z, %exp64, %l_lnx;
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %e_z, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_z;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %vr, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %vr, %vr, %e_nf;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise absolute value on an f32 buffer via the native
/// `abs.f32` instruction. Parameters: input pointer, output pointer, count.
pub(crate) const ABS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry abs_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
abs.f32 %vr, %va;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise logistic sigmoid on an f32 buffer:
/// sigmoid(x) = 1 / (1 + exp(-x)), with exp(-x) computed as
/// ex2.approx(-x * log2(e)). Accuracy is limited by the approx unit.
pub(crate) const SIGMOID_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sigmoid_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
// sigmoid(x) = 1 / (1 + exp(-x))
neg.f32 %neg, %va;
mov.f32 %lg2e, 0f3FB8AA3B;
mul.f32 %neg, %neg, %lg2e;
ex2.approx.f32 %e, %neg;
mov.f32 %one, 0f3F800000;
add.f32 %denom, %one, %e;
div.rn.f32 %vr, %one, %denom;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise logistic sigmoid on an f64 buffer:
/// sigmoid(x) = 1 / (1 + exp(-x)), with exp computed via Cody-Waite range
/// reduction plus a factorial-series Horner polynomial, and the final 2^n
/// scale built from the IEEE bit pattern ((n + 1023) << 52).
///
/// Fixes over the previous revision: the 1/5! coefficient had 17 hex digits
/// (`0d` literals must have exactly 16, so ptxas rejected the module and the
/// op silently fell back to CPU), and the 1/3! Horner term was missing.
pub(crate) const SIGMOID_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sigmoid_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %va, %vr, %e64, %denom, %one, %neg_x;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %va, [%a];
mov.f64 %one, 0d3FF0000000000000;
// sigmoid(x) = 1 / (1 + exp(-x))
neg.f64 %neg_x, %va;
// --- exp(%neg_x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e64, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e64, %e64, %e_nf;
// --- end exp ---
add.f64 %denom, %one, %e64;
div.rn.f64 %vr, %one, %denom;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise tanh on an f32 buffer via the identity
/// tanh(x) = 2 * sigmoid(2x) - 1, with exp computed by `ex2.approx`.
/// Accuracy is limited by the approx unit.
pub(crate) const TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry tanh_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
// tanh(x) = 2*sigmoid(2x) - 1
mov.f32 %two, 0f40000000;
mul.f32 %neg2x, %va, %two;
neg.f32 %neg2x, %neg2x;
mov.f32 %lg2e, 0f3FB8AA3B;
mul.f32 %neg2x, %neg2x, %lg2e;
ex2.approx.f32 %e, %neg2x;
mov.f32 %one, 0f3F800000;
add.f32 %denom, %one, %e;
div.rn.f32 %sig, %one, %denom;
mul.f32 %vr, %two, %sig;
sub.f32 %vr, %vr, %one;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for elementwise tanh on an f64 buffer via
/// tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x)), with exp computed by Cody-Waite
/// range reduction plus a factorial-series Horner polynomial and a 2^n scale
/// built from the IEEE bit pattern ((n + 1023) << 52).
///
/// Fixes over the previous revision: the 1/5! coefficient had 17 hex digits
/// (`0d` literals must have exactly 16, so ptxas rejected the module and the
/// op silently fell back to CPU), and the 1/3! Horner term was missing.
pub(crate) const TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry tanh_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %va, %vr, %e64, %num, %denom, %one, %two, %neg2x;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %va, [%a];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %two, 0d4000000000000000;
// tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))
mul.f64 %neg2x, %va, %two;
neg.f64 %neg2x, %neg2x;
// --- exp(%neg2x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg2x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg2x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E21EED8EFF8D898;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e64, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e64, %e64, %e_nf;
// --- end exp ---
sub.f64 %num, %one, %e64;
add.f64 %denom, %one, %e64;
div.rn.f64 %vr, %num, %denom;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for a single-pass fused f32 Adam step. Per element it reads
/// param/grad/exp_avg/exp_avg_sq, optionally applies L2 weight decay
/// (`g += wd * p`, only when weight_decay > 0), updates both moment
/// estimates, and writes param and the two moments back in place.
///
/// The bias corrections `bc1`/`bc2` are precomputed on the host (this kernel
/// just divides by them), so the host must pass 1 - beta^t values.
pub(crate) const FUSED_ADAM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry fused_adam_kernel(
.param .u64 param_ptr,
.param .u64 grad_ptr,
.param .u64 exp_avg_ptr,
.param .u64 exp_avg_sq_ptr,
.param .f32 beta1,
.param .f32 beta2,
.param .f32 lr,
.param .f32 eps,
.param .f32 bc1,
.param .f32 bc2,
.param .f32 weight_decay,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %p, %g, %m, %v, %off;
.reg .f32 %vp, %vg, %vm, %vv;
.reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
.reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
.reg .f32 %one;
.reg .pred %p_bound, %p_wd;
ld.param.u64 %p, [param_ptr];
ld.param.u64 %g, [grad_ptr];
ld.param.u64 %m, [exp_avg_ptr];
ld.param.u64 %v, [exp_avg_sq_ptr];
ld.param.f32 %b1, [beta1];
ld.param.f32 %b2, [beta2];
ld.param.f32 %f_lr, [lr];
ld.param.f32 %f_eps, [eps];
ld.param.f32 %f_bc1, [bc1];
ld.param.f32 %f_bc2, [bc2];
ld.param.f32 %f_wd, [weight_decay];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p_bound, %r_tid, %n_reg;
@%p_bound bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %p, %p, %off;
add.u64 %g, %g, %off;
add.u64 %m, %m, %off;
add.u64 %v, %v, %off;
ld.global.f32 %vp, [%p];
ld.global.f32 %vg, [%g];
ld.global.f32 %vm, [%m];
ld.global.f32 %vv, [%v];
// L2 weight decay: g = g + wd * p
mov.f32 %one, 0f00000000;
setp.gt.f32 %p_wd, %f_wd, %one;
@%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;
// exp_avg = beta1 * exp_avg + (1 - beta1) * g
mov.f32 %one, 0f3F800000;
sub.f32 %t1, %one, %b1;
mul.f32 %vm, %vm, %b1;
fma.rn.f32 %vm, %t1, %vg, %vm;
// exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
sub.f32 %t2, %one, %b2;
mul.f32 %vv, %vv, %b2;
mul.f32 %t1, %vg, %vg;
fma.rn.f32 %vv, %t2, %t1, %vv;
// m_hat = exp_avg / bc1
div.rn.f32 %m_hat, %vm, %f_bc1;
// v_hat = exp_avg_sq / bc2
div.rn.f32 %v_hat, %vv, %f_bc2;
// denom = sqrt(v_hat) + eps
sqrt.rn.f32 %denom, %v_hat;
add.f32 %denom, %denom, %f_eps;
// param = param - lr * m_hat / denom
div.rn.f32 %update, %m_hat, %denom;
mul.f32 %update, %update, %f_lr;
sub.f32 %vp, %vp, %update;
st.global.f32 [%p], %vp;
st.global.f32 [%m], %vm;
st.global.f32 [%v], %vv;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// PTX for a fused single-timestep GRU cell forward pass over B*H elements
/// using a grid-stride loop (idx advances by blockDim*gridDim until `total`).
///
/// Inputs `input_gates`/`hidden_gates` are the precomputed GEMM results laid
/// out as [B, 3*H] (r, i/z, n sections of width H). Per element it adds the
/// ih/hh biases, computes the reset and update gates with sigmoid
/// (via ex2.approx), the candidate with tanh (as 2*sigmoid(2x)-1), and
/// writes hy = n + z * (hx - n).
///
/// It also stores [r, z, n, hx, hn+b2n] into a [B, 5*H] workspace —
/// presumably consumed by the backward pass; confirm against the caller.
///
/// NOTE(review): bias pointers are dereferenced unconditionally, so the host
/// must always pass valid (possibly zero-filled) bias buffers.
pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry fused_gru_forward_kernel(
.param .u64 input_gates_ptr,
.param .u64 hidden_gates_ptr,
.param .u64 bias_ih_ptr,
.param .u64 bias_hh_ptr,
.param .u64 hx_ptr,
.param .u64 hy_ptr,
.param .u64 workspace_ptr,
.param .u32 hsz,
.param .u32 total
) {
.reg .u32 %tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
.reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
.reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
.reg .u64 %off64, %tmp64;
.reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
.reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
.reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
.reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
.reg .pred %p;
ld.param.u64 %ig, [input_gates_ptr];
ld.param.u64 %hg, [hidden_gates_ptr];
ld.param.u64 %b1, [bias_ih_ptr];
ld.param.u64 %b2, [bias_hh_ptr];
ld.param.u64 %hx, [hx_ptr];
ld.param.u64 %hy, [hy_ptr];
ld.param.u64 %ws, [workspace_ptr];
ld.param.u32 %hsz_reg, [hsz];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tid, %tid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %tid;
mul.lo.u32 %stride, %bdim, %gdim;
mov.f32 %one, 0f3F800000;
LOOP:
setp.ge.u32 %p, %idx, %total_reg;
@%p bra END;
// offset3 = (idx/hsz)*3*hsz + idx%hsz (into [B, 3*H] gates tensor)
div.u32 %batch_idx, %idx, %hsz_reg;
rem.u32 %hmod, %idx, %hsz_reg;
mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
mul.lo.u32 %offset3, %offset3, 3;
add.u32 %offset3, %offset3, %hmod;
// Load input gate components: ir, ii, in
cvt.u64.u32 %off64, %offset3;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %ig, %off64;
ld.global.f32 %ir, [%tmp64];
cvt.u64.u32 %off64, %hsz_reg;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %ii, [%tmp64];
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %in, [%tmp64];
// Load hidden gate components: hr, hi, hn
cvt.u64.u32 %off64, %offset3;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %hg, %off64;
ld.global.f32 %hr, [%tmp64];
cvt.u64.u32 %off64, %hsz_reg;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %hi, [%tmp64];
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %hn, [%tmp64];
// Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
cvt.u64.u32 %off64, %hmod;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %b1, %off64;
ld.global.f32 %b1r, [%tmp64];
cvt.u64.u32 %off64, %hsz_reg;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %b1i, [%tmp64];
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %b1n, [%tmp64];
cvt.u64.u32 %off64, %hmod;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %b2, %off64;
ld.global.f32 %b2r, [%tmp64];
cvt.u64.u32 %off64, %hsz_reg;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %b2i, [%tmp64];
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %b2n, [%tmp64];
// Load hx[idx]
cvt.u64.u32 %off64, %idx;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %hx, %off64;
ld.global.f32 %hx_val, [%tmp64];
// r = sigmoid(ir + hr + b1r + b2r)
add.f32 %rg, %ir, %hr;
add.f32 %rg, %rg, %b1r;
add.f32 %rg, %rg, %b2r;
neg.f32 %tmp, %rg;
mul.f32 %tmp, %tmp, 0f3FB8AA3B;
ex2.approx.f32 %exp_val, %tmp;
add.f32 %denom, %one, %exp_val;
div.rn.f32 %rg, %one, %denom;
// z = sigmoid(ii + hi + b1i + b2i)
add.f32 %zg, %ii, %hi;
add.f32 %zg, %zg, %b1i;
add.f32 %zg, %zg, %b2i;
neg.f32 %tmp, %zg;
mul.f32 %tmp, %tmp, 0f3FB8AA3B;
ex2.approx.f32 %exp_val, %tmp;
add.f32 %denom, %one, %exp_val;
div.rn.f32 %zg, %one, %denom;
// n = tanh(in + b1n + r*(hn + b2n))
add.f32 %tmp, %hn, %b2n;
fma.rn.f32 %ng, %rg, %tmp, %in;
add.f32 %ng, %ng, %b1n;
// tanh via 2*sigmoid(2x)-1
mul.f32 %tmp, %ng, 0f40000000;
neg.f32 %tmp, %tmp;
mul.f32 %tmp, %tmp, 0f3FB8AA3B;
ex2.approx.f32 %exp_val, %tmp;
add.f32 %denom, %one, %exp_val;
div.rn.f32 %ng, %one, %denom;
mul.f32 %ng, %ng, 0f40000000;
sub.f32 %ng, %ng, %one;
// hy = n + z * (hx - n)
sub.f32 %tmp, %hx_val, %ng;
fma.rn.f32 %hy_val, %zg, %tmp, %ng;
// Store hy[idx]
cvt.u64.u32 %off64, %idx;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %hy, %off64;
st.global.f32 [%tmp64], %hy_val;
// Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
mul.lo.u32 %offset5, %offset5, 5;
add.u32 %offset5, %offset5, %hmod;
cvt.u64.u32 %off64, %offset5;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %ws, %off64;
st.global.f32 [%tmp64], %rg;
cvt.u64.u32 %off64, %hsz_reg;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %tmp64, %off64;
st.global.f32 [%tmp64], %zg;
add.u64 %tmp64, %tmp64, %off64;
st.global.f32 [%tmp64], %ng;
add.u64 %tmp64, %tmp64, %off64;
st.global.f32 [%tmp64], %hx_val;
add.u64 %tmp64, %tmp64, %off64;
add.f32 %tmp, %hn, %b2n;
st.global.f32 [%tmp64], %tmp;
add.u32 %idx, %idx, %stride;
bra LOOP;
END:
ret;
}
";
#[cfg(feature = "cuda")]
/// Builds a 1-D launch configuration for an elementwise kernel over `n` items:
/// 256 threads per block and `ceil(n / 256)` blocks (minimum 1, so zero-length
/// launches remain valid).
///
/// Returns `GpuError::ShapeMismatch` when `n` does not fit the kernels' `u32`
/// size parameter.
fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "kernel_launch",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    const BLOCK: u32 = 256;
    // Exact ceiling division in u64: the previous
    // `saturating_add(BLOCK - 1) / BLOCK` clamped at u32::MAX and
    // under-counted the grid by one block for n close to u32::MAX,
    // leaving the tail elements unprocessed.
    let grid = ((n as u64 + (BLOCK as u64 - 1)) / BLOCK as u64) as u32;
    Ok(LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    })
}
#[cfg(feature = "cuda")]
/// Checks that both operands are resident on `device` and have equal lengths.
/// Errors mirror the failing condition: `DeviceMismatch` or `LengthMismatch`.
fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    let ordinal = device.ordinal();
    for buf in [a, b] {
        if buf.device_ordinal() != ordinal {
            return Err(GpuError::DeviceMismatch {
                expected: buf.device_ordinal(),
                got: ordinal,
            });
        }
    }
    if a.len() == b.len() {
        Ok(())
    } else {
        Err(GpuError::LengthMismatch {
            a: a.len(),
            b: b.len(),
        })
    }
}
#[cfg(feature = "cuda")]
/// Ensures the input buffer lives on the device the kernel will run on.
fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    match a.device_ordinal() == device.ordinal() {
        true => Ok(()),
        false => Err(GpuError::DeviceMismatch {
            expected: a.device_ordinal(),
            got: device.ordinal(),
        }),
    }
}
#[cfg(feature = "cuda")]
/// Device-residency check generic over the buffer's element type.
fn validate_device<T>(a: &CudaBuffer<T>, device: &GpuDevice) -> GpuResult<()> {
    if a.device_ordinal() == device.ordinal() {
        return Ok(());
    }
    Err(GpuError::DeviceMismatch {
        expected: a.device_ordinal(),
        got: device.ordinal(),
    })
}
#[cfg(feature = "cuda")]
/// Attempts to run a two-input elementwise f32 kernel, producing a fresh
/// zero-filled output buffer of `a.len()` elements.
///
/// Returns `Ok(None)` when the PTX fails to compile/load for this device so
/// the caller can fall back to a CPU implementation; other failures
/// (allocation, launch) propagate as errors.
///
/// NOTE(review): assumes `a` and `b` have equal lengths — callers go through
/// `validate_binary` first; the launch is sized from `a.len()` only.
fn try_launch_binary(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
// Compilation failure is treated as "kernel unavailable", not a hard error.
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => return Ok(None),
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// Argument order matches the PTX parameter list (a_ptr, b_ptr, out_ptr, n).
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(Some(out))
}
#[cfg(feature = "cuda")]
/// Attempts to run a float4-vectorized two-input elementwise kernel: the
/// launch covers `n / 4` vector lanes instead of `n` scalars.
///
/// Returns `Ok(None)` when the PTX fails to compile/load for this device.
///
/// NOTE(review): assumes `a.len() % 4 == 0` (callers check before choosing
/// this path) — otherwise the trailing `n % 4` elements would stay zero in
/// the output; confirm no caller bypasses that check.
fn try_launch_binary_vec4(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
// One thread handles four consecutive f32 values.
let n4 = (n / 4) as u32;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => return Ok(None),
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n4 as usize)?;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(&n4)
.launch(cfg)?;
}
Ok(Some(out))
}
#[cfg(feature = "cuda")]
/// Attempts to run a one-input elementwise f32 kernel into a fresh
/// zero-filled output buffer of `a.len()` elements.
///
/// Returns `Ok(None)` when the PTX fails to compile/load for this device so
/// the caller can fall back to a CPU implementation.
fn try_launch_unary(
a: &CudaBuffer<f32>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
// Compilation failure is treated as "kernel unavailable", not a hard error.
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => return Ok(None),
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// Argument order matches the PTX parameter list (a_ptr, out_ptr, n).
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(Some(out))
}
#[cfg(feature = "cuda")]
/// Like `try_launch_binary`, but writes into a caller-provided output buffer
/// instead of allocating one. Returns `Ok(false)` when the PTX fails to
/// compile/load (nothing is written), `Ok(true)` after a successful launch.
///
/// NOTE(review): assumes `a`, `b`, and `out` all have at least `a.len()`
/// elements — confirm call sites validate lengths beforehand.
fn try_launch_binary_into(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<bool> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => return Ok(false),
};
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// Argument order matches the PTX parameter list (a_ptr, b_ptr, out_ptr, n).
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(true)
}
#[cfg(feature = "cuda")]
/// Like `try_launch_unary`, but writes into a caller-provided output buffer.
/// Returns `Ok(false)` when the PTX fails to compile/load (nothing is
/// written), `Ok(true)` after a successful launch.
fn try_launch_unary_into(
a: &CudaBuffer<f32>,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<bool> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => return Ok(false),
};
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// Argument order matches the PTX parameter list (a_ptr, out_ptr, n).
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(true)
}
#[cfg(feature = "cuda")]
/// f64 variant of `try_launch_binary`: runs a two-input elementwise kernel
/// into a fresh zero-filled f64 buffer of `a.len()` elements. Returns
/// `Ok(None)` when the PTX fails to compile/load for this device.
fn try_launch_binary_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx, ptx_src, kernel_name, device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => return Ok(None),
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// Argument order matches the PTX parameter list (a_ptr, b_ptr, out_ptr, n).
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(Some(out))
}
#[cfg(feature = "cuda")]
/// f64 variant of `try_launch_unary`: runs a one-input elementwise kernel
/// into a fresh zero-filled f64 buffer. Returns `Ok(None)` when the PTX
/// fails to compile/load for this device.
fn try_launch_unary_f64(
a: &CudaBuffer<f64>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx, ptx_src, kernel_name, device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => return Ok(None),
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// Argument order matches the PTX parameter list (a_ptr, out_ptr, n).
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(Some(out))
}
#[cfg(feature = "cuda")]
/// Host-side fallback for a binary f64 op: downloads both operands, applies
/// `op` elementwise, and uploads the result as a new device buffer.
fn cpu_fallback_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
    op: fn(f64, f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let mut result = Vec::with_capacity(lhs.len());
    for (x, y) in lhs.iter().zip(rhs.iter()) {
        result.push(op(*x, *y));
    }
    cpu_to_gpu(&result, device)
}
#[cfg(feature = "cuda")]
/// Host-side fallback for a unary f64 op: download, map, re-upload.
fn cpu_fallback_unary_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
    op: fn(f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let host = gpu_to_cpu(a, device)?;
    let mapped: Vec<f64> = host.into_iter().map(op).collect();
    cpu_to_gpu(&mapped, device)
}
#[cfg(feature = "cuda")]
/// Attempts to run a broadcasting two-input elementwise f64 kernel.
///
/// `a_strides` / `b_strides` are per-output-dimension element strides
/// (0 for broadcast dims, as produced by `broadcast_strides`), and
/// `out_shape` describes the output; all three are uploaded as small device
/// buffers so the kernel can decompose the flat output index. Returns
/// `Ok(None)` when the PTX fails to compile/load for this device.
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
a_strides: &[u32],
b_strides: &[u32],
out_shape: &[u32],
out_numel: usize,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
use cudarc::driver::PushKernelArg;
let ndim = out_shape.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => return Ok(None),
};
// Stride/shape metadata is passed via device memory, not kernel params.
let a_str_buf = cpu_to_gpu(a_strides, device)?;
let b_str_buf = cpu_to_gpu(b_strides, device)?;
let shape_buf = cpu_to_gpu(out_shape, device)?;
let mut out = alloc_zeros_f64(out_numel, device)?;
let cfg = launch_cfg(out_numel)?;
let n_u32 = out_numel as u32;
let ndim_u32 = ndim as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(a_str_buf.inner())
.arg(b_str_buf.inner())
.arg(shape_buf.inner())
.arg(&n_u32)
.arg(&ndim_u32)
.launch(cfg)?;
}
Ok(Some(out))
}
#[cfg(feature = "cuda")]
/// Host-side fallback for a broadcasting binary f64 op.
///
/// Downloads both operands, then for every flat output index decomposes it
/// into per-dimension coordinates (row-major, innermost last) and maps those
/// through the broadcast strides of each input to find the source elements.
fn cpu_fallback_broadcast_binary_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
device: &GpuDevice,
op: fn(f64, f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
let a_host = gpu_to_cpu(a, device)?;
let b_host = gpu_to_cpu(b, device)?;
let out_numel: usize = out_shape.iter().product();
// Stride 0 on broadcast dims makes the same input element repeat.
let a_str = broadcast_strides(a_shape, out_shape);
let b_str = broadcast_strides(b_shape, out_shape);
let mut result = Vec::with_capacity(out_numel);
for i in 0..out_numel {
let mut remaining = i;
let mut a_idx = 0usize;
let mut b_idx = 0usize;
// Peel coordinates from the innermost dimension outward.
for d in (0..out_shape.len()).rev() {
let coord = remaining % out_shape[d];
remaining /= out_shape[d];
a_idx += coord * a_str[d] as usize;
b_idx += coord * b_str[d] as usize;
}
result.push(op(a_host[a_idx], b_host[b_idx]));
}
cpu_to_gpu(&result, device)
}
#[cfg(feature = "cuda")]
/// f32 twin of `try_launch_broadcast_binary_f64`: runs a broadcasting
/// two-input elementwise kernel, passing per-dimension strides and the output
/// shape through small device buffers. Returns `Ok(None)` when the PTX fails
/// to compile/load for this device.
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
a_strides: &[u32],
b_strides: &[u32],
out_shape: &[u32],
out_numel: usize,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
use cudarc::driver::PushKernelArg;
let ndim = out_shape.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => return Ok(None),
};
// Stride/shape metadata is passed via device memory, not kernel params.
let a_str_buf = cpu_to_gpu(a_strides, device)?;
let b_str_buf = cpu_to_gpu(b_strides, device)?;
let shape_buf = cpu_to_gpu(out_shape, device)?;
let mut out = alloc_zeros_f32(out_numel, device)?;
let cfg = launch_cfg(out_numel)?;
let n_u32 = out_numel as u32;
let ndim_u32 = ndim as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(a_str_buf.inner())
.arg(b_str_buf.inner())
.arg(shape_buf.inner())
.arg(&n_u32)
.arg(&ndim_u32)
.launch(cfg)?;
}
Ok(Some(out))
}
#[cfg(feature = "cuda")]
/// Computes right-aligned broadcast strides (in elements) for `in_shape`
/// viewed through `out_shape`. Output dimensions that are broadcast — either
/// missing on the input side or of size 1 — get stride 0 so the same input
/// element is reused along them.
fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
    let ndim = out_shape.len();
    let in_ndim = in_shape.len();
    let mut strides = vec![0u32; ndim];
    let mut running: u32 = 1;
    for d in (0..ndim).rev() {
        // Right-align: out dim `d` pairs with in dim `d + in_ndim - ndim`
        // when that index exists; otherwise the dim is broadcast (stride 0,
        // already the default).
        if d + in_ndim < ndim {
            continue;
        }
        let in_d = d + in_ndim - ndim;
        strides[d] = if in_shape[in_d] == 1 { 0 } else { running };
        running *= in_shape[in_d] as u32;
    }
    strides
}
#[cfg(feature = "cuda")]
/// Host-side fallback for a binary f32 op: downloads both operands, applies
/// `op` elementwise, and uploads the result as a new device buffer.
fn cpu_fallback_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let a_host = gpu_to_cpu(a, device)?;
    let b_host = gpu_to_cpu(b, device)?;
    let result: Vec<f32> = a_host
        .iter()
        .copied()
        .zip(b_host.iter().copied())
        .map(|(x, y)| op(x, y))
        .collect();
    cpu_to_gpu(&result, device)
}
#[cfg(feature = "cuda")]
/// Host-side fallback for a unary f32 op: download, map, re-upload.
fn cpu_fallback_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(a, device)?;
    let mut result = Vec::with_capacity(host.len());
    for &x in &host {
        result.push(op(x));
    }
    cpu_to_gpu(&result, device)
}
/// Element-wise GPU addition with a vectorized (vec4) fast path; falls back
/// to the scalar kernel, then to a host-side loop if compilation fails.
#[cfg(feature = "cuda")]
pub fn gpu_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;
    // The vec4 kernel requires a length divisible by 4 and large enough for
    // vectorization to pay off.
    let len = a.len();
    if len >= 16 && len % 4 == 0 {
        let fast = try_launch_binary_vec4(a, b, device, ADD_VEC4_PTX, "add_vec4_kernel")?;
        if let Some(out) = fast {
            return Ok(out);
        }
    }
    match try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x + y),
    }
}
/// Element-wise GPU subtraction (`a - b`), with a host-side fallback when the
/// kernel cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;
    match try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x - y),
    }
}
/// Element-wise GPU multiplication with a vectorized (vec4) fast path; falls
/// back to the scalar kernel, then to a host-side loop if compilation fails.
#[cfg(feature = "cuda")]
pub fn gpu_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;
    // Same vec4 eligibility rule as gpu_add.
    let len = a.len();
    if len >= 16 && len % 4 == 0 {
        let fast = try_launch_binary_vec4(a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel")?;
        if let Some(out) = fast {
            return Ok(out);
        }
    }
    match try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x * y),
    }
}
/// Broadcasting addition: both operands are virtually expanded to `out_shape`
/// (NumPy broadcasting rules) and added element-wise.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();
    let launched = try_launch_broadcast_binary(
        a,
        b,
        &strides_a,
        &strides_b,
        &shape,
        numel,
        device,
        BROADCAST_ADD_PTX,
        "broadcast_add_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x + y)
        }
    }
}
/// Broadcasting subtraction (`a - b`) over `out_shape`.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();
    let launched = try_launch_broadcast_binary(
        a,
        b,
        &strides_a,
        &strides_b,
        &shape,
        numel,
        device,
        BROADCAST_SUB_PTX,
        "broadcast_sub_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x - y)
        }
    }
}
/// Broadcasting multiplication over `out_shape`.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();
    let launched = try_launch_broadcast_binary(
        a,
        b,
        &strides_a,
        &strides_b,
        &shape,
        numel,
        device,
        BROADCAST_MUL_PTX,
        "broadcast_mul_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x * y)
        }
    }
}
/// Broadcasting division (`a / b`) over `out_shape`. Division by zero follows
/// IEEE-754 f32 semantics (inf/NaN) on both paths.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();
    let launched = try_launch_broadcast_binary(
        a,
        b,
        &strides_a,
        &strides_b,
        &shape,
        numel,
        device,
        BROADCAST_DIV_PTX,
        "broadcast_div_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x / y)
        }
    }
}
/// CPU fallback for broadcasting binary ops: materialize both operands on the
/// host and walk the flat output index space, mapping each output coordinate
/// back to its (possibly stride-0) source elements.
#[cfg(feature = "cuda")]
fn cpu_fallback_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let numel: usize = out_shape.iter().product();
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let result: Vec<f32> = (0..numel)
        .map(|flat| {
            // Decompose `flat` into per-dimension coordinates (last dim
            // varies fastest) and accumulate the strided source offsets.
            let mut rest = flat;
            let mut ia = 0usize;
            let mut ib = 0usize;
            for d in (0..out_shape.len()).rev() {
                let coord = rest % out_shape[d];
                rest /= out_shape[d];
                ia += coord * strides_a[d] as usize;
                ib += coord * strides_b[d] as usize;
            }
            op(lhs[ia], rhs[ib])
        })
        .collect();
    cpu_to_gpu(&result, device)
}
/// Element-wise negation on the GPU, with a host-side fallback.
#[cfg(feature = "cuda")]
pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, NEG_PTX, "neg_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| -x),
    }
}
/// Element-wise ReLU (`max(x, 0)`) on the GPU, with a host-side fallback.
#[cfg(feature = "cuda")]
pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.max(0.0)),
    }
}
/// Backward pass for ReLU: the gradient passes through only where the
/// forward `input` was strictly positive; elsewhere it is zeroed.
#[cfg(feature = "cuda")]
pub fn gpu_relu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;
    let launched =
        try_launch_binary(grad, input, device, RELU_BACKWARD_PTX, "relu_backward_kernel")?;
    if let Some(out) = launched {
        return Ok(out);
    }
    // Host fallback.
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let mut masked = Vec::with_capacity(g_host.len());
    for (g, x) in g_host.iter().zip(x_host.iter()) {
        masked.push(if *x > 0.0 { *g } else { 0.0 });
    }
    cpu_to_gpu(&masked, device)
}
/// Backward pass for the sigmoid-approximate GELU (`x * sigmoid(1.702 * x)`):
/// d/dx = sigmoid(kx) + k*x*sigmoid(kx)*(1 - sigmoid(kx)), with k = 1.702.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;
    let launched =
        try_launch_binary(grad, input, device, GELU_BACKWARD_PTX, "gelu_backward_kernel")?;
    if let Some(out) = launched {
        return Ok(out);
    }
    // Host fallback using the same sigmoid approximation as the kernel name
    // implies.
    const K: f32 = 1.702;
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let out: Vec<f32> = g_host
        .iter()
        .zip(x_host.iter())
        .map(|(&g, &x)| {
            let s = 1.0 / (1.0 + (-K * x).exp());
            g * (s + K * x * s * (1.0 - s))
        })
        .collect();
    cpu_to_gpu(&out, device)
}
/// Backward pass for exact (erf-based) GELU: out = g * (Phi(x) + x * phi(x)),
/// where Phi is the standard normal CDF and phi its density.
///
/// Tries the dedicated CUDA kernel first; on compilation failure falls back
/// to a CPU implementation that approximates erf with Abramowitz & Stegun
/// formula 7.1.26 (max abs error ~1.5e-7).
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_erf(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;
    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_ERF_PTX,
        "gelu_backward_erf_kernel",
    )? {
        return Ok(out);
    }
    let grad_host = gpu_to_cpu(grad, device)?;
    let input_host = gpu_to_cpu(input, device)?;
    let inv_sqrt_2: f32 = std::f32::consts::FRAC_1_SQRT_2;
    let inv_sqrt_2pi: f32 = 1.0 / (2.0 * std::f32::consts::PI).sqrt();
    // Abramowitz & Stegun 7.1.26 coefficients.
    // BUGFIX: the previous code reused the `p` constant (0.3275911) as the
    // fifth polynomial coefficient; the correct a5 is 1.061405429. The
    // truncated a1..a4 values are also restored to full precision.
    const P: f32 = 0.3275911;
    const A1: f32 = 0.254829592;
    const A2: f32 = -0.284496736;
    const A3: f32 = 1.421413741;
    const A4: f32 = -1.453152027;
    const A5: f32 = 1.061405429;
    let result: Vec<f32> = grad_host
        .iter()
        .zip(input_host.iter())
        .map(|(&g, &x)| {
            let z = x * inv_sqrt_2;
            let az = z.abs();
            let t = 1.0 / (1.0 + P * az);
            // Horner evaluation of a1*t + a2*t^2 + ... + a5*t^5.
            let poly = t * (A1 + t * (A2 + t * (A3 + t * (A4 + t * A5))));
            let erf_abs = 1.0 - poly * (-az * az).exp();
            // erf is odd: restore the sign stripped by `abs` above.
            let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
            let cdf = 0.5 * (1.0 + erf_val);
            let pdf = inv_sqrt_2pi * (-0.5 * x * x).exp();
            g * (cdf + x * pdf)
        })
        .collect();
    cpu_to_gpu(&result, device)
}
/// Gather: out[i] = input[indices[i]] for a flat (1-D) input.
///
/// `indices` are stored as f32 and truncated to integer positions on the CPU
/// path. If the kernel cannot be compiled, falls back to a host-side gather.
/// NOTE(review): out-of-range indices panic on the CPU path; what the GPU
/// kernel does for them is not visible here — confirm in the PTX.
#[cfg(feature = "cuda")]
pub fn gpu_index_select_1d(
    input: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    // The output has one element per index.
    let n = indices.len();
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        INDEX_SELECT_1D_PTX,
        "index_select_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: gather on the host instead.
            let input_host = gpu_to_cpu(input, device)?;
            let indices_host = gpu_to_cpu(indices, device)?;
            let result: Vec<f32> = indices_host
                .iter()
                .map(|&idx_f| input_host[idx_f as usize])
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    // Argument order must match the kernel's parameter list:
    // (input, indices, output, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Scatter-add: out[indices[i]] += grad_output[i], producing a buffer of
/// `input_len` elements. This is the backward of an index-select.
///
/// Indices are carried as f32 and truncated to integer positions on the CPU
/// path; duplicate indices accumulate. NOTE(review): the GPU kernel
/// presumably uses atomic adds to handle duplicates — not visible from here;
/// confirm in the PTX.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_1d(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_unary(grad_output, device)?;
    // One thread per gradient element; the output is sized by the original
    // input length, not by n.
    let n = grad_output.len();
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_1D_PTX,
        "scatter_add_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: accumulate on the host.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; input_len];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                result[idx_f as usize] += go_host[i];
            }
            return cpu_to_gpu(&result, device);
        }
    };
    // Output starts zeroed so the kernel can accumulate into it.
    let mut out = alloc_zeros_f32(input_len, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    // Argument order must match the kernel's parameter list:
    // (grad_output, indices, output, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Replace elements of `input` with `value` wherever the corresponding mask
/// element is "on" (mask values >= 0.5 are treated as true).
#[cfg(feature = "cuda")]
pub fn gpu_masked_fill(
    input: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    value: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_binary(input, mask, device)?;
    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        MASKED_FILL_PTX,
        "masked_fill_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: apply the mask on the host.
            let input_host = gpu_to_cpu(input, device)?;
            let mask_host = gpu_to_cpu(mask, device)?;
            let result: Vec<f32> = input_host
                .iter()
                .zip(mask_host.iter())
                .map(|(&x, &m)| if m >= 0.5 { value } else { x })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    // Argument order must match the kernel's parameter list:
    // (input, mask, output, fill value, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&value)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Zero the gradient wherever the mask is "on" (mask values >= 0.5 are
/// treated as true); pass it through elsewhere.
#[cfg(feature = "cuda")]
pub fn gpu_masked_zero(
    grad: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, mask, device)?;
    match try_launch_binary(grad, mask, device, MASKED_ZERO_PTX, "masked_zero_kernel")? {
        Some(out) => Ok(out),
        None => {
            // Host fallback.
            let g = gpu_to_cpu(grad, device)?;
            let m = gpu_to_cpu(mask, device)?;
            let out: Vec<f32> = g
                .iter()
                .zip(m.iter())
                .map(|(&gv, &mv)| if mv >= 0.5 { 0.0 } else { gv })
                .collect();
            cpu_to_gpu(&out, device)
        }
    }
}
/// Backward pass for sigmoid, expressed with the forward output `y`:
/// dx = dy * y * (1 - y).
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;
    let launched = try_launch_binary(
        grad,
        output,
        device,
        SIGMOID_BACKWARD_PTX,
        "sigmoid_backward_kernel",
    )?;
    if let Some(out) = launched {
        return Ok(out);
    }
    // Host fallback.
    let g = gpu_to_cpu(grad, device)?;
    let y = gpu_to_cpu(output, device)?;
    let mut out = Vec::with_capacity(g.len());
    for (gv, yv) in g.iter().zip(y.iter()) {
        out.push(gv * yv * (1.0 - yv));
    }
    cpu_to_gpu(&out, device)
}
/// Backward pass for tanh, expressed with the forward output `y`:
/// dx = dy * (1 - y^2).
#[cfg(feature = "cuda")]
pub fn gpu_tanh_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;
    let launched = try_launch_binary(
        grad,
        output,
        device,
        TANH_BACKWARD_PTX,
        "tanh_backward_kernel",
    )?;
    if let Some(out) = launched {
        return Ok(out);
    }
    // Host fallback.
    let g = gpu_to_cpu(grad, device)?;
    let y = gpu_to_cpu(output, device)?;
    let mut out = Vec::with_capacity(g.len());
    for (gv, yv) in g.iter().zip(y.iter()) {
        out.push(gv * (1.0 - yv * yv));
    }
    cpu_to_gpu(&out, device)
}
/// Backward pass of row-wise softmax over a (rows, cols) matrix:
/// dx = y * (dy - sum(dy * y)) per row, where `output` is the softmax y.
///
/// Launches one 256-thread block per row with 256*4 bytes of shared memory —
/// presumably for the per-row dot-product reduction; kernel internals are not
/// visible here. Falls back to a host implementation if compilation fails.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_binary(grad, output, device)?;
    let total = grad.len();
    // Caller is expected to pass a `cols` that divides the buffer length;
    // integer division silently truncates otherwise.
    let rows = total / cols;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_BACKWARD_PTX,
        "softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: compute the per-row dot product, then the
            // Jacobian-vector product.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += grad_host[base + c] * output_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row; 256 threads and a 256-float shared buffer.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };
    // Argument order must match the kernel's parameter list:
    // (grad, output, result, rows, cols).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Row-wise log-softmax over a (rows, cols) matrix:
/// out[r][c] = x[r][c] - logsumexp(x[r]).
///
/// One 256-thread block per row (256*4 bytes of shared memory, presumably for
/// the row reductions — kernel internals not visible here). The host fallback
/// uses the numerically stable max-shifted logsumexp.
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax(
    input: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let total = input.len();
    // Caller is expected to pass a `cols` that divides the buffer length.
    let rows = total / cols;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_PTX,
        "log_softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                // Shift by the row max so exp() cannot overflow.
                let mut max_v = f32::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum_exp = 0.0f32;
                for c in 0..cols {
                    sum_exp += (host[base + c] - max_v).exp();
                }
                let log_sum_exp = max_v + sum_exp.ln();
                for c in 0..cols {
                    out[base + c] = host[base + c] - log_sum_exp;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };
    // Argument order must match the kernel's parameter list:
    // (input, output, rows, cols).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Backward pass of row-wise log-softmax, with `output` being the forward
/// log-softmax result: dx = dy - exp(y) * sum(dy) per row.
///
/// One 256-thread block per row (256*4 bytes of shared memory, presumably for
/// the per-row gradient sum — kernel internals not visible here).
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_binary(grad, output, device)?;
    let total = grad.len();
    // Caller is expected to pass a `cols` that divides the buffer length.
    let rows = total / cols;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_BACKWARD_PTX,
        "log_softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: per-row gradient sum, then elementwise combine.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let mut sum_grad = 0.0f32;
                for c in 0..cols {
                    sum_grad += grad_host[base + c];
                }
                for c in 0..cols {
                    // exp(output) reconstructs the softmax probabilities.
                    result[base + c] =
                        grad_host[base + c] - output_host[base + c].exp() * sum_grad;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };
    // Argument order must match the kernel's parameter list:
    // (grad, output, result, rows, cols).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Sum all elements of `a` into a single-element GPU buffer.
///
/// Two-stage reduction: a first kernel writes up to 1024 per-block partial
/// sums; those are then folded on the host (when <= 256 partials) or by one
/// recursive GPU pass. An empty input sums to 0.0.
#[cfg(feature = "cuda")]
pub fn gpu_reduce_sum(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    let n = a.len();
    if n == 0 {
        return cpu_to_gpu(&[0.0f32], device);
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        REDUCE_SUM_PTX,
        "reduce_sum_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: sum on the host.
            let host = gpu_to_cpu(a, device)?;
            let total: f32 = host.iter().sum();
            return cpu_to_gpu(&[total], device);
        }
    };
    const BLOCK: u32 = 256;
    // Grid is capped at 1024 blocks; for larger n each block must cover
    // several chunks. NOTE(review): presumably a grid-stride loop in the
    // kernel — not visible here; confirm in the PTX.
    let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
    let num_blocks = num_blocks.min(1024);
    let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;
    let n_u32 = n as u32;
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };
    // Argument order must match the kernel's parameter list:
    // (input, partial sums, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }
    // A single partial already holds the final sum.
    if num_blocks <= 1 {
        return Ok(partials);
    }
    // Few partials: cheaper to finish on the host than launch again.
    if num_blocks <= 256 {
        let host_partials = gpu_to_cpu(&partials, device)?;
        let total: f32 = host_partials.iter().sum();
        return cpu_to_gpu(&[total], device);
    }
    // 257..=1024 partials: reduce them with one more GPU pass.
    gpu_reduce_sum(&partials, device)
}
/// Stub used when the crate is built without the `cuda` feature: always
/// reports that CUDA support is unavailable.
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_sum(
    _a: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Sum a tensor along one axis. The tensor is viewed as
/// (outer, axis_size, inner); the output has shape (outer, inner), i.e.
/// out[o][i] = sum_k a[o][k][i].
#[cfg(feature = "cuda")]
pub fn gpu_sum_axis(
    a: &CudaBuffer<f32>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_unary(a, device)?;
    // One output element (and one GPU thread) per (outer, inner) pair.
    let total_output = outer * inner;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        SUM_AXIS_PTX,
        "sum_axis_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: sequential reduction along the axis per output
            // element.
            let host = gpu_to_cpu(a, device)?;
            let mut result = vec![0.0f32; total_output];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let mut sum = 0.0f32;
                for k in 0..axis_size {
                    sum += host[outer_idx * axis_size * inner + k * inner + inner_idx];
                }
                *out = sum;
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(total_output, device)?;
    let cfg = launch_cfg(total_output)?;
    let outer_u32 = outer as u32;
    let axis_size_u32 = axis_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total_output as u32;
    // Argument order must match the kernel's parameter list:
    // (input, output, outer, axis_size, inner, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&axis_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Inclusive cumulative sum along one axis of a tensor viewed as
/// (outer, dim_size, inner). One lane — i.e. one (outer, inner) pair — is
/// scanned sequentially; the output has the same shape as the input.
#[cfg(feature = "cuda")]
pub fn gpu_cumsum(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let total = outer * dim_size * inner;
    // One GPU thread per scan lane.
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMSUM_PTX,
        "cumsum_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: sequential running sum per lane.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = 0.0f32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    acc += host[idx];
                    result[idx] = acc;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;
    // Argument order must match the kernel's parameter list:
    // (input, output, outer, dim_size, inner, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Inclusive cumulative product along one axis of a tensor viewed as
/// (outer, dim_size, inner). Mirrors `gpu_cumsum` with multiplication and a
/// running accumulator that starts at 1.0.
#[cfg(feature = "cuda")]
pub fn gpu_cumprod(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let total = outer * dim_size * inner;
    // One GPU thread per scan lane.
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMPROD_PTX,
        "cumprod_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: sequential running product per lane.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = 1.0f32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    acc *= host[idx];
                    result[idx] = acc;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;
    // Argument order must match the kernel's parameter list:
    // (input, output, outer, dim_size, inner, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Cumulative maximum along one axis of a tensor viewed as
/// (outer, dim_size, inner). Returns `(values, indices)` where `indices[i]`
/// is the axis position (as f32) of the running maximum. Ties keep the
/// earliest index (strict `>` comparison).
#[cfg(feature = "cuda")]
pub fn gpu_cummax(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let total = outer * dim_size * inner;
    // One GPU thread per scan lane.
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMMAX_PTX,
        "cummax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: sequential running max per lane, tracking the
            // index of the current maximum.
            let host = gpu_to_cpu(input, device)?;
            let mut vals = vec![0.0f32; total];
            let mut idxs = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = f32::NEG_INFINITY;
                let mut best = 0u32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    if host[idx] > acc {
                        acc = host[idx];
                        best = k as u32;
                    }
                    vals[idx] = acc;
                    idxs[idx] = best as f32;
                }
            }
            return Ok((cpu_to_gpu(&vals, device)?, cpu_to_gpu(&idxs, device)?));
        }
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let mut out_idx = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;
    // Argument order must match the kernel's parameter list:
    // (input, values, indices, outer, dim_size, inner, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(out_idx.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok((out, out_idx))
}
/// Cumulative minimum along one axis of a tensor viewed as
/// (outer, dim_size, inner). Returns `(values, indices)` where `indices[i]`
/// is the axis position (as f32) of the running minimum. Ties keep the
/// earliest index (strict `<` comparison).
#[cfg(feature = "cuda")]
pub fn gpu_cummin(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let total = outer * dim_size * inner;
    // One GPU thread per scan lane.
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMMIN_PTX,
        "cummin_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: sequential running min per lane, tracking the
            // index of the current minimum.
            let host = gpu_to_cpu(input, device)?;
            let mut vals = vec![0.0f32; total];
            let mut idxs = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = f32::INFINITY;
                let mut best = 0u32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    if host[idx] < acc {
                        acc = host[idx];
                        best = k as u32;
                    }
                    vals[idx] = acc;
                    idxs[idx] = best as f32;
                }
            }
            return Ok((cpu_to_gpu(&vals, device)?, cpu_to_gpu(&idxs, device)?));
        }
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let mut out_idx = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;
    // Argument order must match the kernel's parameter list:
    // (input, values, indices, outer, dim_size, inner, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(out_idx.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok((out, out_idx))
}
/// Cumulative log-sum-exp along one axis of a tensor viewed as
/// (outer, dim_size, inner): out[k] = log(sum_{j<=k} exp(x[j])) per lane,
/// computed with a numerically stable running max-shift on the CPU path.
#[cfg(feature = "cuda")]
pub fn gpu_logcumsumexp(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let total = outer * dim_size * inner;
    // One GPU thread per scan lane.
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOGCUMSUMEXP_PTX,
        "logcumsumexp_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                // acc holds log(sum exp) so far; starts at log(0) = -inf so
                // the first step yields exactly x.
                let mut acc = f32::NEG_INFINITY;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    let x = host[idx];
                    // Pairwise stable update:
                    // logaddexp(acc, x) = m + ln(e^(acc-m) + e^(x-m)).
                    let m = acc.max(x);
                    acc = m + ((acc - m).exp() + (x - m).exp()).ln();
                    result[idx] = acc;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;
    // Argument order must match the kernel's parameter list:
    // (input, output, outer, dim_size, inner, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Copy one split — a contiguous range of `split_size` positions starting at
/// `split_offset` along the split axis — out of a tensor whose split axis has
/// `total_along_axis` positions and whose trailing dims hold `inner_size`
/// elements. `n` is the total element count of the output.
#[cfg(feature = "cuda")]
pub fn gpu_strided_split(
    input: &CudaBuffer<f32>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_SPLIT_PTX,
        "strided_split_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: each output position decomposes into an outer
            // slab index and an offset within one (split_size * inner_size)
            // slab. (An unused `outer` computation was removed here.)
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; n];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / (split_size * inner_size);
                let within = i % (split_size * inner_size);
                let src_idx = outer_idx * total_along_axis * inner_size
                    + split_offset * inner_size
                    + within;
                *out = host[src_idx];
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = split_offset as u32;
    let split_sz_u32 = split_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;
    // Argument order must match the kernel's parameter list:
    // (input, output, total_along_axis, offset, split_size, inner, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&split_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Write one concatenation part into `output` in place: `input` supplies a
/// part of `part_size` positions placed at `cat_offset` along the cat axis of
/// a destination whose cat axis has `total_along_axis` positions and whose
/// trailing dims hold `inner_size` elements. `n` is the element count of
/// `input` to copy. Inverse addressing of `gpu_strided_split`.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat(
    input: &CudaBuffer<f32>,
    output: &mut CudaBuffer<f32>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_CAT_PTX,
        "strided_cat_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: download both buffers, scatter the part into
            // the destination copy, and upload the whole result back.
            let host_in = gpu_to_cpu(input, device)?;
            let mut host_out = gpu_to_cpu(output, device)?;
            for (i, &val) in host_in.iter().enumerate().take(n) {
                let outer_idx = i / (part_size * inner_size);
                let within = i % (part_size * inner_size);
                let dst_idx =
                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
                host_out[dst_idx] = val;
            }
            // Replace the device buffer wholesale with the merged copy.
            *output = cpu_to_gpu(&host_out, device)?;
            return Ok(());
        }
    };
    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = cat_offset as u32;
    let part_sz_u32 = part_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;
    // Argument order must match the kernel's parameter list:
    // (input, output, total_along_axis, offset, part_size, inner, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&part_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(())
}
/// Maximum tensor rank supported by the strided-copy kernels: shape/stride
/// information is passed to the kernel as 8 fixed scalar arguments per table.
pub const STRIDED_COPY_MAX_DIMS: usize = 8;
/// Build the fixed-size (8-entry) stride tables passed to the strided-copy
/// kernels.
///
/// Returns `(out_stride, src_stride)`: row-major strides for the dense
/// `out_shape`, and the caller-provided source strides, both widened to u32
/// and padded to `STRIDED_COPY_MAX_DIMS` entries.
///
/// # Errors
/// Rejects rank mismatch between shape and strides, rank > 8, negative
/// source strides, and any stride that does not fit in u32.
#[cfg(feature = "cuda")]
fn pad_strided_copy_params(
    out_shape: &[usize],
    src_strides: &[isize],
    n: usize,
) -> GpuResult<([u32; STRIDED_COPY_MAX_DIMS], [u32; STRIDED_COPY_MAX_DIMS])> {
    if out_shape.len() != src_strides.len() {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad",
            expected: vec![out_shape.len()],
            got: vec![src_strides.len()],
        });
    }
    if out_shape.len() > STRIDED_COPY_MAX_DIMS {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad",
            expected: vec![STRIDED_COPY_MAX_DIMS],
            got: vec![out_shape.len()],
        });
    }
    // Negative strides (reversed views) are not supported by this path.
    for &s in src_strides {
        if s < 0 {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_pad_negative_stride",
                expected: vec![0],
                got: vec![s.unsigned_abs()],
            });
        }
    }
    let rank = out_shape.len();
    // Dense row-major (C-order) strides for the output, built right-to-left.
    let mut out_stride = [0u32; STRIDED_COPY_MAX_DIMS];
    if rank > 0 {
        let mut acc: usize = 1;
        for d in (0..rank).rev() {
            if acc > u32::MAX as usize {
                return Err(GpuError::ShapeMismatch {
                    op: "strided_copy_stride_overflow",
                    expected: vec![u32::MAX as usize],
                    got: vec![acc],
                });
            }
            out_stride[d] = acc as u32;
            acc = acc.saturating_mul(out_shape[d]);
        }
    }
    // Pad unused trailing dims with a stride strictly greater than any flat
    // index (n + 1, at least 1), so `flat / stride` is always 0 there and
    // padded dims contribute nothing to the decoded coordinates.
    let pad_val = (n as u32).saturating_add(1).max(1);
    for d in rank..STRIDED_COPY_MAX_DIMS {
        out_stride[d] = pad_val;
    }
    let mut src_stride_out = [0u32; STRIDED_COPY_MAX_DIMS];
    for d in 0..rank {
        let s = src_strides[d];
        // Non-negative at this point (checked above); guard the u32 cast.
        if s as usize > u32::MAX as usize {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_src_stride_overflow",
                expected: vec![u32::MAX as usize],
                got: vec![s as usize],
            });
        }
        src_stride_out[d] = s as u32;
    }
    Ok((out_stride, src_stride_out))
}
/// Materialize a strided view of `input` (shape `out_shape`, source strides
/// `src_strides`, starting at element `src_offset`) into a dense, row-major
/// output buffer.
///
/// The 8-entry stride tables are passed to the kernel as 16 individual
/// scalar arguments; unused dims are padded so they decode to coordinate 0
/// (see `pad_strided_copy_params`).
#[cfg(feature = "cuda")]
pub fn gpu_strided_copy(
    input: &CudaBuffer<f32>,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let n: usize = out_shape.iter().product();
    let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;
    // Empty output: nothing to copy.
    if n == 0 {
        return alloc_zeros_f32(0, device);
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_COPY_PTX,
        "strided_copy_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: decode each flat output index into per-dim
            // coordinates via the dense out strides, then re-encode with the
            // source strides. Padded dims have stride > n, so coord is 0.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; n];
            for i in 0..n {
                let mut flat = i as u32;
                let mut src_idx = src_offset as u32;
                for d in 0..STRIDED_COPY_MAX_DIMS {
                    let os = out_stride[d];
                    let ss = src_stride[d];
                    let coord = flat / os;
                    flat -= coord * os;
                    src_idx += coord * ss;
                }
                result[i] = host[src_idx as usize];
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let src_offset_u32 = src_offset as u32;
    let n_u32 = n as u32;
    // Argument order must match the kernel's parameter list: (input, output,
    // src_offset, n, 8 output strides, 8 source strides).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&src_offset_u32)
            .arg(&n_u32)
            .arg(&out_stride[0])
            .arg(&out_stride[1])
            .arg(&out_stride[2])
            .arg(&out_stride[3])
            .arg(&out_stride[4])
            .arg(&out_stride[5])
            .arg(&out_stride[6])
            .arg(&out_stride[7])
            .arg(&src_stride[0])
            .arg(&src_stride[1])
            .arg(&src_stride[2])
            .arg(&src_stride[3])
            .arg(&src_stride[4])
            .arg(&src_stride[5])
            .arg(&src_stride[6])
            .arg(&src_stride[7])
            .launch(cfg)?;
    }
    Ok(out)
}
/// f64 variant of `gpu_strided_copy`: materialize a strided view into a
/// dense, row-major output buffer.
///
/// The f64 PTX is derived from the f32 kernel by textual rewriting (see
/// `get_f64_ptx` / `ptx_f32_to_f64`) and cached once per process in a
/// `OnceLock`.
#[cfg(feature = "cuda")]
pub fn gpu_strided_copy_f64(
    input: &CudaBuffer<f64>,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    // Process-wide cache for the rewritten f64 PTX source.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    validate_device(input, device)?;
    let n: usize = out_shape.iter().product();
    let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;
    // Empty output: nothing to copy.
    if n == 0 {
        return alloc_zeros_f64(0, device);
    }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(
        &CACHE,
        STRIDED_COPY_PTX,
        "strided_copy_kernel",
        "strided_copy_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_copy_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: same index decode/re-encode as the f32 path.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f64; n];
            for i in 0..n {
                let mut flat = i as u32;
                let mut src_idx = src_offset as u32;
                for d in 0..STRIDED_COPY_MAX_DIMS {
                    let os = out_stride[d];
                    let ss = src_stride[d];
                    let coord = flat / os;
                    flat -= coord * os;
                    src_idx += coord * ss;
                }
                result[i] = host[src_idx as usize];
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let src_offset_u32 = src_offset as u32;
    let n_u32 = n as u32;
    // Argument order must match the kernel's parameter list: (input, output,
    // src_offset, n, 8 output strides, 8 source strides).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&src_offset_u32)
            .arg(&n_u32)
            .arg(&out_stride[0])
            .arg(&out_stride[1])
            .arg(&out_stride[2])
            .arg(&out_stride[3])
            .arg(&out_stride[4])
            .arg(&out_stride[5])
            .arg(&out_stride[6])
            .arg(&out_stride[7])
            .arg(&src_stride[0])
            .arg(&src_stride[1])
            .arg(&src_stride[2])
            .arg(&src_stride[3])
            .arg(&src_stride[4])
            .arg(&src_stride[5])
            .arg(&src_stride[6])
            .arg(&src_stride[7])
            .launch(cfg)?;
    }
    Ok(out)
}
#[cfg(feature = "cuda")]
/// Multiplies every element of `a` by `scalar`, returning a new f32 buffer.
/// Falls back to a host-side map if the kernel cannot be compiled.
pub fn gpu_scale(
a: &CudaBuffer<f32>,
scalar: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(a, device)?;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SCALE_PTX,
"scale_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Compilation failed: scale on the CPU and upload the result.
let host = gpu_to_cpu(a, device)?;
let result: Vec<f32> = host.iter().map(|&x| x * scalar).collect();
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// SAFETY: argument order matches the kernel: (src, dst, scalar, n).
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&scalar)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Row-wise softmax over a `rows x cols` f32 matrix stored row-major.
///
/// One CUDA block per row, 256 threads per block. Falls back to a
/// numerically-stable host implementation (max-subtraction before exp) if the
/// kernel cannot be compiled.
pub fn gpu_softmax(
input: &CudaBuffer<f32>,
rows: usize,
cols: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SOFTMAX_PTX,
"softmax_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
let host = gpu_to_cpu(input, device)?;
let mut out = vec![0.0f32; host.len()];
for r in 0..rows {
let base = r * cols;
// Stable softmax: subtract the row max before exponentiating.
let mut max_v = f32::NEG_INFINITY;
for c in 0..cols {
max_v = max_v.max(host[base + c]);
}
let mut sum = 0.0f32;
for c in 0..cols {
let e = (host[base + c] - max_v).exp();
out[base + c] = e;
sum += e;
}
let inv = 1.0 / sum;
for c in 0..cols {
out[base + c] *= inv;
}
}
return cpu_to_gpu(&out, device);
}
};
let mut out = alloc_zeros_f32(rows * cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
// One block per row (grid clamped to at least 1); 256 threads per block and
// 256 * 4 bytes of dynamic shared memory — presumably one f32 slot per
// thread for the block reduction; confirm against the kernel source.
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4, };
// SAFETY: argument order matches the kernel: (src, dst, rows, cols).
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Inverted dropout: zeroes elements whose per-index hash is below
/// `threshold`, multiplies survivors by `scale` (presumably 1/(1-p) — confirm
/// at call sites). `seed` perturbs the per-element hash so masks differ
/// between calls.
pub fn gpu_dropout(
input: &CudaBuffer<f32>,
threshold: u32,
scale: f32,
seed: u32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
DROPOUT_PTX,
"dropout_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Host fallback: multiplicative-hash + xorshift per element.
// NOTE(review): assumed to reproduce the kernel's RNG bit-for-bit so
// CPU and GPU paths produce the same mask — confirm against the PTX.
let host = gpu_to_cpu(input, device)?;
let result: Vec<f32> = host
.iter()
.enumerate()
.map(|(i, &x)| {
let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
r ^= r << 13;
r ^= r >> 17;
r ^= r << 5;
if r < threshold { 0.0 } else { x * scale }
})
.collect();
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// SAFETY: argument order matches the kernel:
// (src, dst, n, threshold, scale, seed).
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&threshold)
.arg(&scale)
.arg(&seed)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Inverted dropout for f64 buffers: zeroes elements whose per-index hash is
/// below `threshold`, multiplies survivors by `scale` (presumably 1/(1-p) —
/// confirm at call sites). `seed` perturbs the per-element hash.
///
/// The f64 PTX is derived from the f32 kernel on first use and memoized in
/// `CACHE`. Falls back to a host-side map if compilation fails.
pub fn gpu_dropout_f64(
    input: &CudaBuffer<f64>,
    threshold: u32,
    scale: f64,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    // Consistency fix: the f32 variant validates its input (validate_unary)
    // and the other manual f64 entry points validate the buffer's device
    // (see gpu_strided_copy_f64); this function previously skipped the check,
    // so a buffer from a different device would be launched unchecked.
    validate_device(input, device)?;
    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, DROPOUT_PTX, "dropout_kernel", "dropout_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx, ptx, "dropout_f64_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: multiplicative-hash + xorshift per element.
            // NOTE(review): assumed to reproduce the kernel's RNG bit-for-bit
            // so CPU and GPU paths produce the same mask — confirm.
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f64> = host
                .iter()
                .enumerate()
                .map(|(i, &x)| {
                    let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
                    r ^= r << 13;
                    r ^= r >> 17;
                    r ^= r << 5;
                    if r < threshold { 0.0 } else { x * scale }
                })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    // SAFETY: argument order matches the kernel:
    // (src, dst, n, threshold, scale, seed).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Stub for builds without the `cuda` feature: always reports that CUDA
/// support was not compiled in.
#[cfg(not(feature = "cuda"))]
pub fn gpu_dropout_f64(
    _input: &CudaBuffer<f64>,
    _threshold: u32,
    _scale: f64,
    _seed: u32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
/// Transposes a row-major `m x n` f32 matrix into an `n x m` one.
/// Falls back to a host-side transpose if the kernel cannot be compiled.
pub fn gpu_transpose_2d(
input: &CudaBuffer<f32>,
m: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let total = m * n;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
TRANSPOSE_2D_PTX,
"transpose_2d_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Compilation failed: transpose on the CPU.
let host = gpu_to_cpu(input, device)?;
let mut out = vec![0.0f32; total];
for i in 0..m {
for j in 0..n {
// out[j][i] = in[i][j]
out[j * m + i] = host[i * n + j];
}
}
return cpu_to_gpu(&out, device);
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let m_u32 = m as u32;
let n_u32 = n as u32;
let total_u32 = total as u32;
// SAFETY: argument order matches the kernel: (src, dst, m, n, total).
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&m_u32)
.arg(&n_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Permutes a 4-D tensor from axis order (0,1,2,3) to (0,2,1,3) — i.e. the
/// output has shape (d0, d2, d1, d3). This is the common attention-head
/// reshuffle layout; the index formulas below define it precisely.
pub fn gpu_permute_0213(
input: &CudaBuffer<f32>,
d0: usize,
d1: usize,
d2: usize,
d3: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let total = d0 * d1 * d2 * d3;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
PERMUTE_0213_PTX,
"permute_0213_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Compilation failed: permute on the CPU.
let host = gpu_to_cpu(input, device)?;
let mut out = vec![0.0f32; total];
for i0 in 0..d0 {
for i1 in 0..d1 {
for i2 in 0..d2 {
for i3 in 0..d3 {
// input  laid out as (d0, d1, d2, d3)
// output laid out as (d0, d2, d1, d3)
let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
out[out_idx] = host[in_idx];
}
}
}
}
return cpu_to_gpu(&out, device);
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let d0_u32 = d0 as u32;
let d1_u32 = d1 as u32;
let d2_u32 = d2 as u32;
let d3_u32 = d3 as u32;
let total_u32 = total as u32;
// SAFETY: argument order matches the kernel: (src, dst, d0, d1, d2, d3, total).
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&d0_u32)
.arg(&d1_u32)
.arg(&d2_u32)
.arg(&d3_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Matrix multiply `C = A (m x k) * B (k x n)` via a custom kernel intended
/// for small shapes. If the kernel cannot be compiled, delegates to the BLAS
/// path rather than a naive host loop.
///
/// NOTE(review): unlike most entry points here, this one performs no device
/// validation on `a`/`b` — presumably the BLAS fallback or callers validate;
/// confirm.
pub fn gpu_small_matmul(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let total = m * n;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SMALL_MATMUL_PTX,
"small_matmul_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Kernel unavailable: use the general BLAS-backed matmul instead.
return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
}
};
let mut c = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let m_u32 = m as u32;
let k_u32 = k as u32;
let n_u32 = n as u32;
let total_u32 = total as u32;
// SAFETY: argument order matches the kernel: (a, b, c, m, k, n, total).
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(c.inner_mut())
.arg(&m_u32)
.arg(&k_u32)
.arg(&n_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(c)
}
#[cfg(feature = "cuda")]
/// Batched matrix multiply for small per-batch shapes.
///
/// A batch of one is just an ordinary matmul and is routed to the small-matmul
/// kernel; larger batches go straight to the BLAS-backed batched path.
pub fn gpu_small_bmm(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    match batch {
        1 => gpu_small_matmul(a, b, m, k, n, device),
        _ => crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device),
    }
}
#[cfg(feature = "cuda")]
/// Embedding lookup: copies one row of width `d` out of `weight`, selected by
/// the index stored (as an f32) in `idx`.
///
/// NOTE(review): the host fallback reads only `idx[0]`, i.e. a single-token
/// lookup — presumably the kernel does the same; confirm. The fallback also
/// panics if `idx` is empty or the row is out of range.
pub fn gpu_embed_lookup(
idx: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
d: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
EMBED_LOOKUP_PTX,
"embed_lookup_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Compilation failed: do the lookup on the CPU.
let idx_host = gpu_to_cpu(idx, device)?;
let weight_host = gpu_to_cpu(weight, device)?;
// Index is stored as f32; truncate to select the embedding row.
let row = idx_host[0] as usize;
let start = row * d;
let out = weight_host[start..start + d].to_vec();
return cpu_to_gpu(&out, device);
}
};
let mut out = alloc_zeros_f32(d, device)?;
let cfg = launch_cfg(d)?;
let d_u32 = d as u32;
// SAFETY: argument order matches the kernel: (idx, weight, dst, d).
unsafe {
stream
.launch_builder(&f)
.arg(idx.inner())
.arg(weight.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Writes `src` (logical shape `n_batch x d`) into position `pos` of `dst`
/// (logical shape `n_batch x max_len x d`), in place. Used for appending a
/// step into a preallocated sequence buffer (e.g. a KV-cache-style layout —
/// TODO confirm caller intent).
pub fn gpu_slice_write(
src: &CudaBuffer<f32>,
dst: &mut CudaBuffer<f32>,
n_batch: usize,
d: usize,
max_len: usize,
pos: usize,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let total = n_batch * d;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SLICE_WRITE_PTX,
"slice_write_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Compilation failed: round-trip the whole destination buffer through
// the host, patch the slice, then replace `dst` wholesale.
let src_host = gpu_to_cpu(src, device)?;
let mut dst_host = gpu_to_cpu(dst, device)?;
for b in 0..n_batch {
for di in 0..d {
dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
}
}
let new_dst = cpu_to_gpu(&dst_host, device)?;
*dst = new_dst;
return Ok(());
}
};
let cfg = launch_cfg(total)?;
let n_u32 = total as u32;
let d_u32 = d as u32;
let max_len_u32 = max_len as u32;
let pos_u32 = pos as u32;
// SAFETY: argument order matches the kernel: (src, dst, n, d, max_len, pos).
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(dst.inner_mut())
.arg(&n_u32)
.arg(&d_u32)
.arg(&max_len_u32)
.arg(&pos_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
/// Reads the first `len` positions out of `src` (logical shape
/// `n_batch x max_len x d`), producing a packed `n_batch x len x d` buffer.
/// Inverse of [`gpu_slice_write`]'s padded layout.
pub fn gpu_slice_read(
src: &CudaBuffer<f32>,
n_batch: usize,
d: usize,
len: usize,
max_len: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let total = n_batch * len * d;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SLICE_READ_PTX,
"slice_read_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Compilation failed: repack on the CPU.
let host = gpu_to_cpu(src, device)?;
let mut out = vec![0.0f32; total];
for b in 0..n_batch {
for r in 0..len {
for di in 0..d {
// Dense (len-strided) output from the max_len-strided source.
out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
}
}
}
return cpu_to_gpu(&out, device);
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let total_u32 = total as u32;
let d_u32 = d as u32;
let len_u32 = len as u32;
let max_len_u32 = max_len as u32;
// SAFETY: argument order matches the kernel: (src, dst, total, d, len, max_len).
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(out.inner_mut())
.arg(&total_u32)
.arg(&d_u32)
.arg(&len_u32)
.arg(&max_len_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// GELU activation using the sigmoid approximation x * sigma(1.702 * x),
/// with a host fallback when the kernel is unavailable.
pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_PTX, "gelu_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(input, device, |x| {
            let s = 1.0 / (1.0 + (-1.702 * x).exp());
            x * s
        }),
    }
}
#[cfg(feature = "cuda")]
/// GELU activation using the tanh approximation
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with host fallback.
pub fn gpu_gelu_tanh(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_TANH_PTX, "gelu_tanh_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(input, device, |x| {
            let k: f32 = 0.7978845608; // sqrt(2/pi)
            let a3: f32 = 0.044715;
            let u = k * (x + a3 * x * x * x);
            0.5 * x * (1.0 + u.tanh())
        }),
    }
}
#[cfg(feature = "cuda")]
/// Exact-form GELU, x * 0.5 * (1 + erf(x / sqrt(2))), with a host fallback
/// that evaluates erf via the Abramowitz–Stegun 7.1.26 polynomial.
pub fn gpu_gelu_erf(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_ERF_PTX, "gelu_erf_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(input, device, |x| {
            let z = x * std::f32::consts::FRAC_1_SQRT_2;
            let az = z.abs();
            // Rational polynomial approximation of erf(|z|).
            let q = 1.0 / (1.0 + 0.3275911 * az);
            let poly = q * (0.254829592 + q * (-0.284496736 + q * (1.421413741 + q * (-1.453152027 + q * 1.061405429))));
            let erf_abs = 1.0 - poly * (-az * az).exp();
            // erf is odd: restore the sign of z.
            let erf_val = if z < 0.0 { -erf_abs } else { erf_abs };
            x * 0.5 * (1.0 + erf_val)
        }),
    }
}
#[cfg(feature = "cuda")]
/// Backward pass of the tanh-approximation GELU: multiplies `grad` by
/// d/dx [0.5 * x * (1 + tanh(u))] where u = sqrt(2/pi) * (x + 0.044715 x^3).
pub fn gpu_gelu_backward_tanh(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;
    let launched = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_TANH_PTX,
        "gelu_backward_tanh_kernel",
    )?;
    if let Some(result) = launched {
        return Ok(result);
    }
    // Host fallback: same derivative, evaluated element-wise.
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let mut grads_in = Vec::with_capacity(g_host.len());
    for (&g, &x) in g_host.iter().zip(x_host.iter()) {
        let k: f32 = 0.7978845608; // sqrt(2/pi)
        let a3: f32 = 0.044715;
        let a3_d: f32 = 0.134145; // 3 * 0.044715
        let u = k * (x + a3 * x * x * x);
        let t = u.tanh();
        let dt = 1.0 - t * t; // sech^2(u)
        let du = k * (1.0 + a3_d * x * x);
        grads_in.push(g * (0.5 * (1.0 + t) + 0.5 * x * dt * du));
    }
    cpu_to_gpu(&grads_in, device)
}
#[cfg(feature = "cuda")]
/// SiLU (swish) activation, x * sigma(x), with host fallback.
pub fn gpu_silu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, SILU_PTX, "silu_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(input, device, |x| {
            let sig = 1.0 / (1.0 + (-x).exp());
            x * sig
        }),
    }
}
#[cfg(feature = "cuda")]
/// Backward pass of SiLU: grad * (sigma(x) + x * sigma(x) * (1 - sigma(x))).
pub fn gpu_silu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;
    let launched = try_launch_binary(
        grad,
        input,
        device,
        SILU_BACKWARD_PTX,
        "silu_backward_kernel",
    )?;
    if let Some(result) = launched {
        return Ok(result);
    }
    // Host fallback: same derivative, evaluated element-wise.
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let mut grads_in = Vec::with_capacity(g_host.len());
    for (&g, &x) in g_host.iter().zip(x_host.iter()) {
        let sig = 1.0 / (1.0 + (-x).exp());
        grads_in.push(g * (sig + x * sig * (1.0 - sig)));
    }
    cpu_to_gpu(&grads_in, device)
}
#[cfg(feature = "cuda")]
/// ELU activation: x for x > 0, alpha * (exp(x) - 1) otherwise.
/// Falls back to a host-side map if the kernel cannot be compiled.
pub fn gpu_elu(
input: &CudaBuffer<f32>,
alpha: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ELU_PTX,
"elu_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Compilation failed: evaluate on the CPU.
let host = gpu_to_cpu(input, device)?;
let result: Vec<f32> = host
.iter()
.map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
.collect();
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// SAFETY: argument order matches the kernel: (src, dst, n, alpha).
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&alpha)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Backward pass of ELU: grad for x > 0, grad * alpha * exp(x) otherwise.
/// Falls back to a host-side map if the kernel cannot be compiled.
pub fn gpu_elu_backward(
grad: &CudaBuffer<f32>,
input: &CudaBuffer<f32>,
alpha: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_binary(grad, input, device)?;
let n = grad.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ELU_BACKWARD_PTX,
"elu_backward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Compilation failed: evaluate on the CPU.
let grad_host = gpu_to_cpu(grad, device)?;
let input_host = gpu_to_cpu(input, device)?;
let result: Vec<f32> = grad_host
.iter()
.zip(input_host.iter())
.map(|(&g, &x)| if x > 0.0 { g } else { g * alpha * x.exp() })
.collect();
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// SAFETY: argument order matches the kernel: (grad, input, dst, n, alpha).
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&alpha)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Mish activation, x * tanh(softplus(x)), with host fallback.
/// Softplus saturates to x for x > 20 to avoid exp overflow.
pub fn gpu_mish(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, MISH_PTX, "mish_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(input, device, |x| {
            let softplus = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
            x * softplus.tanh()
        }),
    }
}
#[cfg(feature = "cuda")]
/// Backward pass of Mish:
/// grad * (tanh(sp) + x * sigma(x) * (1 - tanh(sp)^2)), sp = softplus(x).
pub fn gpu_mish_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;
    let launched = try_launch_binary(
        grad,
        input,
        device,
        MISH_BACKWARD_PTX,
        "mish_backward_kernel",
    )?;
    if let Some(result) = launched {
        return Ok(result);
    }
    // Host fallback: same derivative, evaluated element-wise.
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let mut grads_in = Vec::with_capacity(g_host.len());
    for (&g, &x) in g_host.iter().zip(x_host.iter()) {
        // Softplus saturates to x for x > 20 to avoid exp overflow.
        let sp = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
        let t = sp.tanh();
        let sig = 1.0 / (1.0 + (-x).exp());
        grads_in.push(g * (t + x * sig * (1.0 - t * t)));
    }
    cpu_to_gpu(&grads_in, device)
}
#[cfg(feature = "cuda")]
/// Clamps every element of `input` into `[min_val, max_val]`.
/// Falls back to a host-side map if the kernel cannot be compiled.
/// Note: `.max(min).min(max)` (not `f32::clamp`) is used deliberately so that
/// min_val > max_val does not panic — max_val wins in that case.
pub fn gpu_clamp(
input: &CudaBuffer<f32>,
min_val: f32,
max_val: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
CLAMP_PTX,
"clamp_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Compilation failed: clamp on the CPU.
let host = gpu_to_cpu(input, device)?;
let result: Vec<f32> = host
.iter()
.map(|&x| x.max(min_val).min(max_val))
.collect();
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// SAFETY: argument order matches the kernel: (src, dst, n, min_val, max_val).
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&min_val)
.arg(&max_val)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Element-wise division `a / b` of two f32 buffers, with host fallback.
pub fn gpu_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;
    if let Some(result) = try_launch_binary(a, b, device, DIV_PTX, "div_kernel")? {
        return Ok(result);
    }
    // Kernel unavailable: divide on the host and upload the result.
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let mut quotients = Vec::with_capacity(lhs.len());
    for (&x, &y) in lhs.iter().zip(rhs.iter()) {
        quotients.push(x / y);
    }
    cpu_to_gpu(&quotients, device)
}
#[cfg(feature = "cuda")]
/// Element-wise natural exponential of an f32 buffer, with host fallback.
pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, EXP_PTX, "exp_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, |x| x.exp()),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise natural logarithm of an f32 buffer, with host fallback.
pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, LOG_PTX, "log_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, |x| x.ln()),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise square root of an f32 buffer, with host fallback.
pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, |x| x.sqrt()),
    }
}
#[cfg(feature = "cuda")]
/// Raises every element of `a` to the power `exponent` (f32 powf semantics).
/// Falls back to a host-side map if the kernel cannot be compiled.
pub fn gpu_pow(
a: &CudaBuffer<f32>,
exponent: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(a, device)?;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
POW_PTX,
"pow_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// Compilation failed: compute on the CPU.
let host = gpu_to_cpu(a, device)?;
let result: Vec<f32> = host.iter().map(|&x| x.powf(exponent)).collect();
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
// SAFETY: argument order matches the kernel: (src, dst, exponent, n).
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&exponent)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Element-wise absolute value of an f32 buffer, with host fallback.
pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, ABS_PTX, "abs_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, |x| x.abs()),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise logistic sigmoid of an f32 buffer, with host fallback.
pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, |x| 1.0 / (1.0 + (-x).exp())),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise hyperbolic tangent of an f32 buffer, with host fallback.
pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, TANH_PTX, "tanh_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, |x| x.tanh()),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise addition of two f64 buffers, with host fallback.
/// Errors if the buffers differ in length.
pub fn gpu_add_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (la, lb) = (a.len(), b.len());
    if la != lb {
        return Err(GpuError::LengthMismatch { a: la, b: lb });
    }
    let ptx = get_f64_ptx(&CACHE, ADD_PTX, "add_kernel", "add_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "add_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x + y),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise subtraction `a - b` of two f64 buffers, with host fallback.
/// Errors if the buffers differ in length.
pub fn gpu_sub_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (la, lb) = (a.len(), b.len());
    if la != lb {
        return Err(GpuError::LengthMismatch { a: la, b: lb });
    }
    let ptx = get_f64_ptx(&CACHE, SUB_PTX, "sub_kernel", "sub_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "sub_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x - y),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise multiplication of two f64 buffers, with host fallback.
/// Errors if the buffers differ in length.
pub fn gpu_mul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (la, lb) = (a.len(), b.len());
    if la != lb {
        return Err(GpuError::LengthMismatch { a: la, b: lb });
    }
    let ptx = get_f64_ptx(&CACHE, MUL_PTX, "mul_kernel", "mul_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "mul_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x * y),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise division `a / b` of two f64 buffers, with host fallback.
/// Errors if the buffers differ in length.
pub fn gpu_div_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (la, lb) = (a.len(), b.len());
    if la != lb {
        return Err(GpuError::LengthMismatch { a: la, b: lb });
    }
    let ptx = get_f64_ptx(&CACHE, DIV_PTX, "div_kernel", "div_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "div_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x / y),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise negation of an f64 buffer, with host fallback.
pub fn gpu_neg_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let kernel_src = get_f64_ptx(&CACHE, NEG_PTX, "neg_kernel", "neg_f64_kernel");
    match try_launch_unary_f64(a, device, kernel_src, "neg_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |x| -x),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise ReLU (max(x, 0)) of an f64 buffer, with host fallback.
pub fn gpu_relu_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let kernel_src = get_f64_ptx(&CACHE, RELU_PTX, "relu_kernel", "relu_f64_kernel");
    match try_launch_unary_f64(a, device, kernel_src, "relu_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |x| x.max(0.0)),
    }
}
#[cfg(feature = "cuda")]
/// Multiplies every element of an f64 buffer by `scalar`.
///
/// The f64 PTX is derived from the f32 kernel on first use and memoized in
/// `CACHE`. Falls back to a host-side map if compilation fails.
pub fn gpu_scale_f64(
    a: &CudaBuffer<f64>,
    scalar: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    // Consistency fix: the f32 variant (gpu_scale) validates its input via
    // validate_unary and gpu_strided_copy_f64 uses validate_device; this
    // function previously launched without any device check.
    validate_device(a, device)?;
    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, SCALE_PTX, "scale_kernel", "scale_f64_kernel");
    if let Ok(f) = crate::module_cache::get_or_compile(
        ctx, ptx, "scale_f64_kernel", device.ordinal() as u32,
    ) {
        let mut out = alloc_zeros_f64(n, device)?;
        let cfg = launch_cfg(n)?;
        let n_u32 = n as u32;
        // SAFETY: argument order matches the kernel: (src, dst, scalar, n).
        unsafe {
            stream
                .launch_builder(&f)
                .arg(a.inner())
                .arg(out.inner_mut())
                .arg(&scalar)
                .arg(&n_u32)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    // Kernel unavailable: scale on the host and upload the result.
    let a_host = gpu_to_cpu(a, device)?;
    let result: Vec<f64> = a_host.iter().map(|&x| x * scalar).collect();
    cpu_to_gpu(&result, device)
}
#[cfg(feature = "cuda")]
/// Element-wise natural exponential of an f64 buffer, with host fallback.
/// Uses a dedicated f64 PTX source rather than a translated f32 kernel.
pub fn gpu_exp_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, EXP_F64_PTX, "exp_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |x| x.exp()),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise natural logarithm of an f64 buffer, with host fallback.
/// Uses a dedicated f64 PTX source rather than a translated f32 kernel.
pub fn gpu_log_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, LOG_F64_PTX, "log_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |x| x.ln()),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise square root of an f64 buffer, with host fallback.
pub fn gpu_sqrt_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let kernel_src = get_f64_ptx(&CACHE, SQRT_PTX, "sqrt_kernel", "sqrt_f64_kernel");
    match try_launch_unary_f64(a, device, kernel_src, "sqrt_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |x| x.sqrt()),
    }
}
#[cfg(feature = "cuda")]
/// Raises every element of an f64 buffer to the power `exponent` (powf
/// semantics). Uses a dedicated f64 PTX source; falls back to a host-side
/// map if compilation fails.
pub fn gpu_pow_f64(
    a: &CudaBuffer<f64>,
    exponent: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    // Consistency fix: the f32 variant (gpu_pow) validates its input via
    // validate_unary; this function previously launched without any device
    // check. validate_device is the f64-compatible check used elsewhere.
    validate_device(a, device)?;
    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();
    if let Ok(f) = crate::module_cache::get_or_compile(
        ctx, POW_F64_PTX, "pow_f64_kernel", device.ordinal() as u32,
    ) {
        let mut out = alloc_zeros_f64(n, device)?;
        let cfg = launch_cfg(n)?;
        let n_u32 = n as u32;
        // SAFETY: argument order matches the kernel: (src, dst, exponent, n).
        unsafe {
            stream
                .launch_builder(&f)
                .arg(a.inner())
                .arg(out.inner_mut())
                .arg(&exponent)
                .arg(&n_u32)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    // Kernel unavailable: compute on the host and upload the result.
    let a_host = gpu_to_cpu(a, device)?;
    let result: Vec<f64> = a_host.iter().map(|&x| x.powf(exponent)).collect();
    cpu_to_gpu(&result, device)
}
#[cfg(feature = "cuda")]
/// Element-wise absolute value of an f64 buffer, with host fallback.
pub fn gpu_abs_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let kernel_src = get_f64_ptx(&CACHE, ABS_PTX, "abs_kernel", "abs_f64_kernel");
    match try_launch_unary_f64(a, device, kernel_src, "abs_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |x| x.abs()),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise logistic sigmoid of an f64 buffer, with host fallback.
/// Uses a dedicated f64 PTX source rather than a translated f32 kernel.
pub fn gpu_sigmoid_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, SIGMOID_F64_PTX, "sigmoid_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |x| 1.0 / (1.0 + (-x).exp())),
    }
}
#[cfg(feature = "cuda")]
/// Element-wise hyperbolic tangent of an f64 buffer, with host fallback.
/// Uses a dedicated f64 PTX source rather than a translated f32 kernel.
pub fn gpu_tanh_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, TANH_F64_PTX, "tanh_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |x| x.tanh()),
    }
}
#[cfg(feature = "cuda")]
/// Backward pass of ReLU for f64: passes `grad` through where the forward
/// input was positive, zero elsewhere. Errors on length mismatch.
pub fn gpu_relu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (gl, il) = (grad.len(), input.len());
    if gl != il {
        return Err(GpuError::LengthMismatch { a: gl, b: il });
    }
    let kernel_src = get_f64_ptx(
        &CACHE,
        RELU_BACKWARD_PTX,
        "relu_backward_kernel",
        "relu_backward_f64_kernel",
    );
    match try_launch_binary_f64(grad, input, device, kernel_src, "relu_backward_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary_f64(grad, input, device, |g, x| if x > 0.0 { g } else { 0.0 }),
    }
}
#[cfg(feature = "cuda")]
/// Backward pass of sigmoid for f64: grad * out * (1 - out), where `output`
/// is the forward activation. Errors on length mismatch.
pub fn gpu_sigmoid_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (gl, ol) = (grad.len(), output.len());
    if gl != ol {
        return Err(GpuError::LengthMismatch { a: gl, b: ol });
    }
    let kernel_src = get_f64_ptx(
        &CACHE,
        SIGMOID_BACKWARD_PTX,
        "sigmoid_backward_kernel",
        "sigmoid_backward_f64_kernel",
    );
    match try_launch_binary_f64(grad, output, device, kernel_src, "sigmoid_backward_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary_f64(grad, output, device, |g, o| g * o * (1.0 - o)),
    }
}
#[cfg(feature = "cuda")]
/// Backward pass of tanh for f64: grad * (1 - out^2), where `output` is the
/// forward activation. Errors on length mismatch.
pub fn gpu_tanh_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (gl, ol) = (grad.len(), output.len());
    if gl != ol {
        return Err(GpuError::LengthMismatch { a: gl, b: ol });
    }
    let kernel_src = get_f64_ptx(
        &CACHE,
        TANH_BACKWARD_PTX,
        "tanh_backward_kernel",
        "tanh_backward_f64_kernel",
    );
    match try_launch_binary_f64(grad, output, device, kernel_src, "tanh_backward_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary_f64(grad, output, device, |g, o| g * (1.0 - o * o)),
    }
}
#[cfg(feature = "cuda")]
/// Broadcasting element-wise addition of two f64 buffers, NumPy-style:
/// each input is addressed through strides with broadcast axes zeroed.
pub fn gpu_broadcast_add_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let out_numel: usize = out_shape.iter().product();
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let kernel_src = get_f64_ptx(
        &CACHE,
        BROADCAST_ADD_PTX,
        "broadcast_add_kernel",
        "broadcast_add_f64_kernel",
    );
    let launched = try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &shape_u32,
        out_numel,
        device,
        kernel_src,
        "broadcast_add_f64_kernel",
    )?;
    match launched {
        Some(result) => Ok(result),
        None => cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| x + y),
    }
}
#[cfg(feature = "cuda")]
/// Broadcasting element-wise subtraction `a - b` of two f64 buffers,
/// NumPy-style: each input is addressed through strides with broadcast axes
/// zeroed.
pub fn gpu_broadcast_sub_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let out_numel: usize = out_shape.iter().product();
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let kernel_src = get_f64_ptx(
        &CACHE,
        BROADCAST_SUB_PTX,
        "broadcast_sub_kernel",
        "broadcast_sub_f64_kernel",
    );
    let launched = try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &shape_u32,
        out_numel,
        device,
        kernel_src,
        "broadcast_sub_f64_kernel",
    )?;
    match launched {
        Some(result) => Ok(result),
        None => cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| x - y),
    }
}
#[cfg(feature = "cuda")]
/// Broadcasting element-wise multiplication of two f64 buffers, NumPy-style:
/// each input is addressed through strides with broadcast axes zeroed.
pub fn gpu_broadcast_mul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // f64 PTX is derived from the f32 kernel once and memoized here.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let out_numel: usize = out_shape.iter().product();
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let kernel_src = get_f64_ptx(
        &CACHE,
        BROADCAST_MUL_PTX,
        "broadcast_mul_kernel",
        "broadcast_mul_f64_kernel",
    );
    let launched = try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &shape_u32,
        out_numel,
        device,
        kernel_src,
        "broadcast_mul_f64_kernel",
    )?;
    match launched {
        Some(result) => Ok(result),
        None => cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| x * y),
    }
}
#[cfg(feature = "cuda")]
/// Elementwise `a / b` with broadcasting of `a_shape`/`b_shape` to
/// `out_shape`; falls back to the CPU path when the kernel cannot launch.
pub fn gpu_broadcast_div_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    // Output geometry, then per-input broadcast strides.
    let n_out: usize = out_shape.iter().product();
    let dims_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    // f64 PTX is derived once from the f32 kernel and cached for the process.
    let ptx = get_f64_ptx(&CACHE, BROADCAST_DIV_PTX, "broadcast_div_kernel", "broadcast_div_f64_kernel");
    let launched = try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims_u32,
        n_out,
        device,
        ptx,
        "broadcast_div_f64_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| x / y)
        }
    }
}
#[cfg(feature = "cuda")]
/// Sums every element of `a` into a GPU buffer holding a single f64.
///
/// Strategy: one kernel pass produces up to 1024 block-level partial sums,
/// then either the single partial is returned directly, a small partial set
/// (<= 256) is finished on the CPU, or the function recurses on the partials.
/// A PTX compile failure falls back to a full CPU-side sum.
pub fn gpu_reduce_sum_f64(
a: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let n = a.len();
// Empty input: the sum is defined as 0.0.
if n == 0 {
return cpu_to_gpu(&[0.0f64], device);
}
let ctx = device.context();
let stream = device.stream();
// f64 kernel is rewritten from the f32 PTX once and cached per process.
let ptx = get_f64_ptx(&CACHE, REDUCE_SUM_PTX, "reduce_sum_kernel", "reduce_sum_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"reduce_sum_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// CPU fallback: copy everything to the host and sum there.
let host = gpu_to_cpu(a, device)?;
let total: f64 = host.iter().sum();
return cpu_to_gpu(&[total], device);
}
};
const BLOCK: u32 = 256;
// Ceil-divide; grid is capped at 1024 blocks (the kernel must therefore
// cover the remainder internally -- presumably by looping; confirm in PTX).
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = alloc_zeros_f64(num_blocks as usize, device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
// One partial: it already is the final answer.
if num_blocks <= 1 {
return Ok(partials);
}
// Few partials: cheaper to finish on the host than launch again.
if num_blocks <= 256 {
let host_partials = gpu_to_cpu(&partials, device)?;
let total: f64 = host_partials.iter().sum();
return cpu_to_gpu(&[total], device);
}
// Many partials: reduce them recursively on the GPU.
gpu_reduce_sum_f64(&partials, device)
}
#[cfg(feature = "cuda")]
/// Sums `a`, viewed as an `(outer, axis_size, inner)` tensor, over the middle
/// axis, producing `outer * inner` elements.
///
/// Falls back to an equivalent CPU triple loop when the f64 PTX cannot be
/// compiled for this device.
pub fn gpu_sum_axis_f64(
a: &CudaBuffer<f64>,
outer: usize,
axis_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let total_output = outer * inner;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, SUM_AXIS_PTX, "sum_axis_kernel", "sum_axis_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"sum_axis_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// CPU fallback: one output element per (outer, inner) pair,
// accumulated over the reduced axis.
let host = gpu_to_cpu(a, device)?;
let mut result = vec![0.0f64; total_output];
for (i, out) in result.iter_mut().enumerate() {
let outer_idx = i / inner;
let inner_idx = i % inner;
let mut sum = 0.0f64;
for k in 0..axis_size {
sum += host[outer_idx * axis_size * inner + k * inner + inner_idx];
}
*out = sum;
}
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f64(total_output, device)?;
let cfg = launch_cfg(total_output)?;
// Kernel takes all extents as u32; argument order must match the PTX.
let outer_u32 = outer as u32;
let axis_size_u32 = axis_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total_output as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&axis_size_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
/// CUDA-less build: full-buffer f64 reduction is unavailable.
pub fn gpu_reduce_sum_f64(
    _a: &CudaBuffer<f64>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
/// CUDA-less build: axis-wise f64 reduction is unavailable.
pub fn gpu_sum_axis_f64(
    _a: &CudaBuffer<f64>,
    _outer: usize,
    _axis_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
/// Transposes a row-major `m x n` matrix into an `n x m` buffer.
///
/// Falls back to a CPU double loop when the f64 PTX cannot be compiled.
pub fn gpu_transpose_2d_f64(
input: &CudaBuffer<f64>,
m: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let total = m * n;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, TRANSPOSE_2D_PTX, "transpose_2d_kernel", "transpose_2d_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"transpose_2d_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// CPU fallback: out[j][i] = in[i][j].
let host = gpu_to_cpu(input, device)?;
let mut out = vec![0.0f64; total];
for i in 0..m {
for j in 0..n {
out[j * m + i] = host[i * n + j];
}
}
return cpu_to_gpu(&out, device);
}
};
let mut out = alloc_zeros_f64(total, device)?;
let cfg = launch_cfg(total)?;
let m_u32 = m as u32;
let n_u32 = n as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&m_u32)
.arg(&n_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Permutes a contiguous `(d0, d1, d2, d3)` tensor to `(d0, d2, d1, d3)`
/// (the common "swap heads and sequence" layout change).
///
/// Falls back to a CPU quadruple loop when the f64 PTX cannot be compiled.
pub fn gpu_permute_0213_f64(
input: &CudaBuffer<f64>,
d0: usize,
d1: usize,
d2: usize,
d3: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let total = d0 * d1 * d2 * d3;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, PERMUTE_0213_PTX, "permute_0213_kernel", "permute_0213_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"permute_0213_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// CPU fallback: recompute both linear indices per element; the
// output index swaps the roles of i1 and i2.
let host = gpu_to_cpu(input, device)?;
let mut out = vec![0.0f64; total];
for i0 in 0..d0 {
for i1 in 0..d1 {
for i2 in 0..d2 {
for i3 in 0..d3 {
let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
out[out_idx] = host[in_idx];
}
}
}
}
return cpu_to_gpu(&out, device);
}
};
let mut out = alloc_zeros_f64(total, device)?;
let cfg = launch_cfg(total)?;
let d0_u32 = d0 as u32;
let d1_u32 = d1 as u32;
let d2_u32 = d2 as u32;
let d3_u32 = d3 as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&d0_u32)
.arg(&d1_u32)
.arg(&d2_u32)
.arg(&d3_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Copies one split of `input` along an axis: rows
/// `[split_offset, split_offset + split_size)` of an axis with
/// `total_along_axis` entries, each entry `inner_size` elements wide.
/// `n` is the number of output elements.
///
/// Falls back to a CPU index-mapping loop when the f64 PTX cannot compile.
pub fn gpu_strided_split_f64(
input: &CudaBuffer<f64>,
total_along_axis: usize,
split_offset: usize,
split_size: usize,
inner_size: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, STRIDED_SPLIT_PTX, "strided_split_kernel", "strided_split_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"strided_split_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// CPU fallback: map each output index to its source position.
let host = gpu_to_cpu(input, device)?;
let mut result = vec![0.0f64; n];
for (i, out) in result.iter_mut().enumerate() {
let outer_idx = i / (split_size * inner_size);
let within = i % (split_size * inner_size);
let src_idx =
outer_idx * total_along_axis * inner_size + split_offset * inner_size + within;
*out = host[src_idx];
}
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let total_ax_u32 = total_along_axis as u32;
let offset_u32 = split_offset as u32;
let split_sz_u32 = split_size as u32;
let inner_u32 = inner_size as u32;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&total_ax_u32)
.arg(&offset_u32)
.arg(&split_sz_u32)
.arg(&inner_u32)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// Writes `input` (one part of a concatenation, `n` elements) into `output`
/// at offset `cat_offset` along the concatenation axis. The axis has
/// `total_along_axis` entries in the output, each `inner_size` wide, and the
/// part contributes `part_size` of them.
///
/// Falls back to a CPU round-trip (read output, patch, write back) when the
/// f64 PTX cannot be compiled.
pub fn gpu_strided_cat_f64(
    input: &CudaBuffer<f64>,
    output: &mut CudaBuffer<f64>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    validate_device(input, device)?;
    // Fix: also verify the destination buffer is on this device; previously
    // only the source was checked, unlike the sibling launchers.
    validate_device(output, device)?;
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, STRIDED_CAT_PTX, "strided_cat_kernel", "strided_cat_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_cat_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: scatter each input element to its destination slot.
            let host_in = gpu_to_cpu(input, device)?;
            let mut host_out = gpu_to_cpu(output, device)?;
            for (i, &val) in host_in.iter().enumerate().take(n) {
                let outer_idx = i / (part_size * inner_size);
                let within = i % (part_size * inner_size);
                let dst_idx =
                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
                host_out[dst_idx] = val;
            }
            *output = cpu_to_gpu(&host_out, device)?;
            return Ok(());
        }
    };
    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = cat_offset as u32;
    let part_sz_u32 = part_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&part_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(())
}
#[cfg(feature = "cuda")]
/// Gathers `input[indices[i]]` for each i; indices are carried as f32 and
/// truncated to usize on the CPU path.
///
/// Falls back to a CPU gather when the f64 PTX cannot be compiled.
/// NOTE(review): the fallback panics on an out-of-range index; the kernel's
/// behavior for such indices is not visible here -- confirm it bounds-checks.
pub fn gpu_index_select_1d_f64(
input: &CudaBuffer<f64>,
indices: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let n = indices.len();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, INDEX_SELECT_1D_PTX, "index_select_1d_kernel", "index_select_1d_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"index_select_1d_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
let input_host = gpu_to_cpu(input, device)?;
let indices_host = gpu_to_cpu(indices, device)?;
let result: Vec<f64> = indices_host
.iter()
.map(|&idx_f| input_host[idx_f as usize])
.collect();
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Backward of 1-D index-select: accumulates `grad_output[i]` into position
/// `indices[i]` of a zero-initialized buffer of `input_len` elements.
/// Duplicate indices sum their contributions.
///
/// Falls back to a CPU scatter-add loop when the f64 PTX cannot compile.
pub fn gpu_scatter_add_1d_f64(
grad_output: &CudaBuffer<f64>,
indices: &CudaBuffer<f32>,
input_len: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(grad_output, device)?;
let n = grad_output.len();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, SCATTER_ADD_1D_PTX, "scatter_add_1d_kernel", "scatter_add_1d_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"scatter_add_1d_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
let go_host = gpu_to_cpu(grad_output, device)?;
let idx_host = gpu_to_cpu(indices, device)?;
let mut result = vec![0.0f64; input_len];
for (i, &idx_f) in idx_host.iter().enumerate() {
result[idx_f as usize] += go_host[i];
}
return cpu_to_gpu(&result, device);
}
};
// Output is sized by the original input, not by the gradient.
let mut out = alloc_zeros_f64(input_len, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(grad_output.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Returns a copy of `input` where every position with a nonzero `mask` byte
/// is replaced by `value`.
///
/// Falls back to a CPU zip/map when the f64 PTX cannot be compiled.
pub fn gpu_masked_fill_f64(
input: &CudaBuffer<f64>,
mask: &CudaBuffer<u8>,
value: f64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, MASKED_FILL_PTX, "masked_fill_kernel", "masked_fill_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"masked_fill_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
let input_host = gpu_to_cpu(input, device)?;
let mask_host = gpu_to_cpu(mask, device)?;
let result: Vec<f64> = input_host
.iter()
.zip(mask_host.iter())
.map(|(&x, &m)| if m != 0 { value } else { x })
.collect();
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(mask.inner())
.arg(out.inner_mut())
.arg(&value)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Backward of masked-fill: zeroes `grad` at every position with a nonzero
/// `mask` byte, passing the gradient through elsewhere.
///
/// Falls back to a CPU zip/map when the f64 PTX cannot be compiled.
pub fn gpu_masked_zero_f64(
grad: &CudaBuffer<f64>,
mask: &CudaBuffer<u8>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(grad, device)?;
let n = grad.len();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, MASKED_ZERO_PTX, "masked_zero_kernel", "masked_zero_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"masked_zero_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
let grad_host = gpu_to_cpu(grad, device)?;
let mask_host = gpu_to_cpu(mask, device)?;
let result: Vec<f64> = grad_host
.iter()
.zip(mask_host.iter())
.map(|(&g, &m)| if m != 0 { 0.0 } else { g })
.collect();
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(mask.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Writes one time-step slice into a `(n_batch, max_len, d)` destination:
/// for each batch `b`, copies `src[b*d .. (b+1)*d]` into row `pos` of `dst`.
///
/// Falls back to a CPU round-trip (read dst, patch, write back) when the
/// f64 PTX cannot be compiled.
pub fn gpu_slice_write_f64(
    src: &CudaBuffer<f64>,
    dst: &mut CudaBuffer<f64>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    // Fix: validate both buffers' devices up front, like the sibling
    // launchers; previously no validation was performed here.
    validate_device(src, device)?;
    validate_device(dst, device)?;
    let total = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, SLICE_WRITE_PTX, "slice_write_kernel", "slice_write_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "slice_write_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: patch the destination on the host.
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for b in 0..n_batch {
                for di in 0..d {
                    dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
                }
            }
            let new_dst = cpu_to_gpu(&dst_host, device)?;
            *dst = new_dst;
            return Ok(());
        }
    };
    let cfg = launch_cfg(total)?;
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }
    Ok(())
}
#[cfg(feature = "cuda")]
/// Reads the first `len` rows of each batch from a `(n_batch, max_len, d)`
/// source into a dense `(n_batch, len, d)` buffer.
///
/// Falls back to a CPU copy loop when the f64 PTX cannot be compiled.
pub fn gpu_slice_read_f64(
    src: &CudaBuffer<f64>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    // Fix: validate the source buffer's device, consistent with the sibling
    // launchers; previously no validation was performed here.
    validate_device(src, device)?;
    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, SLICE_READ_PTX, "slice_read_kernel", "slice_read_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "slice_read_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: compact the padded layout on the host.
            let host = gpu_to_cpu(src, device)?;
            let mut out = vec![0.0f64; total];
            for b in 0..n_batch {
                for r in 0..len {
                    for di in 0..d {
                        out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };
    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
#[cfg(feature = "cuda")]
/// Looks up a single embedding row: copies `d` elements of `weight` starting
/// at row `idx[0]` (index stored as f32).
///
/// Falls back to a CPU slice copy when the f64 PTX cannot be compiled.
/// NOTE(review): the fallback reads `idx_host[0]` unconditionally and panics
/// if `idx` is empty -- confirm callers never pass an empty index buffer.
pub fn gpu_embed_lookup_f64(
idx: &CudaBuffer<f32>,
weight: &CudaBuffer<f64>,
d: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, EMBED_LOOKUP_PTX, "embed_lookup_kernel", "embed_lookup_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"embed_lookup_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
let idx_host = gpu_to_cpu(idx, device)?;
let weight_host = gpu_to_cpu(weight, device)?;
let row = idx_host[0] as usize;
let start = row * d;
let out = weight_host[start..start + d].to_vec();
return cpu_to_gpu(&out, device);
}
};
let mut out = alloc_zeros_f64(d, device)?;
let cfg = launch_cfg(d)?;
let d_u32 = d as u32;
unsafe {
stream
.launch_builder(&f)
.arg(idx.inner())
.arg(weight.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Batched embedding lookup: for each of the `n` f32-encoded indices, copies
/// `d` elements of the corresponding `weight` row into the output.
///
/// Falls back to a CPU gather when the f64 PTX cannot be compiled.
pub fn gpu_embed_lookup_batch_f64(
indices: &CudaBuffer<f32>,
weight: &CudaBuffer<f64>,
n: usize,
d: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let total = n * d;
// Nothing to look up: return an empty buffer without touching the kernel.
if total == 0 {
return alloc_zeros_f64(0, device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel", "embed_lookup_batch_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"embed_lookup_batch_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
let idx_host = gpu_to_cpu(indices, device)?;
let weight_host = gpu_to_cpu(weight, device)?;
let mut out = Vec::with_capacity(total);
for &idx_f in &idx_host {
let row = idx_f as usize;
let start = row * d;
out.extend_from_slice(&weight_host[start..start + d]);
}
return cpu_to_gpu(&out, device);
}
};
let mut out = alloc_zeros_f64(total, device)?;
let cfg = launch_cfg(total)?;
let d_u32 = d as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(indices.inner())
.arg(weight.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// Backward of batched embedding lookup: accumulates each `d`-wide row of
/// `grad_output` into row `indices[i]` of a zero-initialized
/// `(num_embeddings, d)` gradient buffer. Duplicate indices sum.
///
/// Falls back to a CPU scatter-add loop when the f64 PTX cannot compile.
pub fn gpu_scatter_add_rows_f64(
grad_output: &CudaBuffer<f64>,
indices: &CudaBuffer<f32>,
num_embeddings: usize,
d: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let n = indices.len();
let total = n * d;
// No gradients to scatter: return an all-zero table.
if total == 0 {
return alloc_zeros_f64(num_embeddings * d, device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel", "scatter_add_rows_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"scatter_add_rows_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
let go_host = gpu_to_cpu(grad_output, device)?;
let idx_host = gpu_to_cpu(indices, device)?;
let mut result = vec![0.0f64; num_embeddings * d];
for (i, &idx_f) in idx_host.iter().enumerate() {
let row = idx_f as usize;
for j in 0..d {
result[row * d + j] += go_host[i * d + j];
}
}
return cpu_to_gpu(&result, device);
}
};
let mut out = alloc_zeros_f64(num_embeddings * d, device)?;
let cfg = launch_cfg(total)?;
let d_u32 = d as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(grad_output.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// Fused Adam/AdamW step updating `param`, `exp_avg` (first moment) and
/// `exp_avg_sq` (second moment) in place. `bc1`/`bc2` are the bias
/// corrections (1 - beta^t) already computed by the caller; `weight_decay`
/// is applied by adding `weight_decay * param` to the gradient.
///
/// Falls back to an equivalent CPU loop when the PTX cannot be compiled.
pub fn gpu_fused_adam(
    param: &mut CudaBuffer<f32>,
    grad: &CudaBuffer<f32>,
    exp_avg: &mut CudaBuffer<f32>,
    exp_avg_sq: &mut CudaBuffer<f32>,
    beta1: f32,
    beta2: f32,
    lr: f32,
    eps: f32,
    bc1: f32,
    bc2: f32,
    weight_decay: f32,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let n = param.len();
    // Fix: report the length of the first buffer that actually disagrees
    // with `param`; previously the error always blamed `grad.len()`, even
    // when the mismatch was in `exp_avg` or `exp_avg_sq`.
    let mismatch = [grad.len(), exp_avg.len(), exp_avg_sq.len()]
        .iter()
        .copied()
        .find(|&len| len != n);
    if let Some(b) = mismatch {
        return Err(GpuError::LengthMismatch { a: n, b });
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_ADAM_PTX,
        "fused_adam_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: identical update rule, executed on the host,
            // with all three state buffers written back afterwards.
            let mut p_host = gpu_to_cpu(param, device)?;
            let g_host = gpu_to_cpu(grad, device)?;
            let mut m_host = gpu_to_cpu(exp_avg, device)?;
            let mut v_host = gpu_to_cpu(exp_avg_sq, device)?;
            for i in 0..n {
                let mut g = g_host[i];
                if weight_decay > 0.0 {
                    g += weight_decay * p_host[i];
                }
                m_host[i] = beta1 * m_host[i] + (1.0 - beta1) * g;
                v_host[i] = beta2 * v_host[i] + (1.0 - beta2) * g * g;
                let m_hat = m_host[i] / bc1;
                let v_hat = v_host[i] / bc2;
                p_host[i] -= lr * m_hat / (v_hat.sqrt() + eps);
            }
            *param = cpu_to_gpu(&p_host, device)?;
            *exp_avg = cpu_to_gpu(&m_host, device)?;
            *exp_avg_sq = cpu_to_gpu(&v_host, device)?;
            return Ok(());
        }
    };
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(param.inner_mut())
            .arg(grad.inner())
            .arg(exp_avg.inner_mut())
            .arg(exp_avg_sq.inner_mut())
            .arg(&beta1)
            .arg(&beta2)
            .arg(&lr)
            .arg(&eps)
            .arg(&bc1)
            .arg(&bc2)
            .arg(&weight_decay)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(())
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
/// CUDA-less build: the fused Adam step is unavailable.
pub fn gpu_fused_adam(
    _param: &mut CudaBuffer<f32>,
    _grad: &CudaBuffer<f32>,
    _exp_avg: &mut CudaBuffer<f32>,
    _exp_avg_sq: &mut CudaBuffer<f32>,
    _beta1: f32, _beta2: f32, _lr: f32, _eps: f32,
    _bc1: f32, _bc2: f32, _weight_decay: f32,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
/// Fused GRU cell forward: combines the precomputed input/hidden gate
/// projections with both biases and `hx` to produce the next hidden state
/// `hy`, plus a `batch * 5 * hsz` workspace buffer returned for the
/// backward pass. No CPU fallback: PTX compile failure is an error.
pub fn gpu_fused_gru_forward(
input_gates: &CudaBuffer<f32>,
hidden_gates: &CudaBuffer<f32>,
bias_ih: &CudaBuffer<f32>,
bias_hh: &CudaBuffer<f32>,
hx: &CudaBuffer<f32>,
hsz: usize,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
use cudarc::driver::PushKernelArg;
// NOTE(review): assumes hx.len() is an exact multiple of hsz; the integer
// division silently truncates otherwise -- confirm callers guarantee this.
let total = hx.len(); let batch = total / hsz;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
FUSED_GRU_FORWARD_PTX,
"fused_gru_forward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
return Err(GpuError::PtxCompileFailed {
kernel: "fused_gru_forward_kernel",
});
}
};
let mut hy = alloc_zeros_f32(total, device)?;
// Workspace layout is dictated by the kernel (5 slots per hidden unit).
let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;
let cfg = launch_cfg(total)?;
let hsz_u32 = hsz as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input_gates.inner())
.arg(hidden_gates.inner())
.arg(bias_ih.inner())
.arg(bias_hh.inner())
.arg(hx.inner())
.arg(hy.inner_mut())
.arg(workspace.inner_mut())
.arg(&hsz_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok((hy, workspace))
}
#[cfg(not(feature = "cuda"))]
/// CUDA-less build: the fused GRU forward pass is unavailable.
pub fn gpu_fused_gru_forward(
    _input_gates: &CudaBuffer<f32>,
    _hidden_gates: &CudaBuffer<f32>,
    _bias_ih: &CudaBuffer<f32>,
    _bias_hh: &CudaBuffer<f32>,
    _hx: &CudaBuffer<f32>,
    _hsz: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// 2D max-pool forward over an NCHW `input` with kernel (kh, kw), stride
/// (sh, sw) and padding (ph, pw). Returns the pooled buffer and its
/// `[batch, channels, h_out, w_out]` shape. No CPU fallback: a PTX compile
/// failure is reported as `PtxCompileFailed`.
pub fn gpu_maxpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;
    // Fix: guard the output-size arithmetic. A kernel larger than the padded
    // input underflowed `usize` (panic in debug, huge wrap-around alloc in
    // release), and a zero stride divided by zero.
    let padded_h = h_in + 2 * ph;
    let padded_w = w_in + 2 * pw;
    if sh == 0 || sw == 0 || padded_h < kh || padded_w < kw {
        return Err(GpuError::ShapeMismatch {
            op: "maxpool2d_forward",
            expected: vec![kh, kw],
            got: vec![padded_h, padded_w],
        });
    }
    let h_out = (padded_h - kh) / sh + 1;
    let w_out = (padded_w - kw) / sw + 1;
    let total = batch * channels * h_out * w_out;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx, MAXPOOL2D_PTX, "maxpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "maxpool2d_forward_kernel" }),
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    // Kernel takes every extent as u32; argument order must match the PTX.
    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok((out, [batch, channels, h_out, w_out]))
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
/// CUDA-less build: max-pooling is unavailable.
pub fn gpu_maxpool2d(
    _input: &CudaBuffer<f32>,
    _batch: usize,
    _channels: usize,
    _h_in: usize,
    _w_in: usize,
    _kh: usize,
    _kw: usize,
    _sh: usize,
    _sw: usize,
    _ph: usize,
    _pw: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// 2D average-pool forward over an NCHW `input` with kernel (kh, kw), stride
/// (sh, sw) and padding (ph, pw). Returns the pooled buffer and its
/// `[batch, channels, h_out, w_out]` shape. No CPU fallback: a PTX compile
/// failure is reported as `PtxCompileFailed`.
pub fn gpu_avgpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;
    // Fix: guard the output-size arithmetic. A kernel larger than the padded
    // input underflowed `usize` (panic in debug, huge wrap-around alloc in
    // release), and a zero stride divided by zero.
    let padded_h = h_in + 2 * ph;
    let padded_w = w_in + 2 * pw;
    if sh == 0 || sw == 0 || padded_h < kh || padded_w < kw {
        return Err(GpuError::ShapeMismatch {
            op: "avgpool2d_forward",
            expected: vec![kh, kw],
            got: vec![padded_h, padded_w],
        });
    }
    let h_out = (padded_h - kh) / sh + 1;
    let w_out = (padded_w - kw) / sw + 1;
    let total = batch * channels * h_out * w_out;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx, AVGPOOL2D_PTX, "avgpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "avgpool2d_forward_kernel" }),
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    // Kernel takes every extent as u32; argument order must match the PTX.
    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok((out, [batch, channels, h_out, w_out]))
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
/// CUDA-less build: average-pooling is unavailable.
pub fn gpu_avgpool2d(
    _input: &CudaBuffer<f32>,
    _batch: usize,
    _channels: usize,
    _h_in: usize,
    _w_in: usize,
    _kh: usize,
    _kw: usize,
    _sh: usize,
    _sw: usize,
    _ph: usize,
    _pw: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// BatchNorm forward -- NOT YET IMPLEMENTED on the GPU path.
///
/// This stub only warms the module cache by compiling the kernel, then
/// always returns a sentinel `ShapeMismatch` error so callers take their
/// own fallback. All tensor parameters are intentionally unused.
pub fn gpu_batchnorm_forward(
_input: &CudaBuffer<f32>,
_weight: &CudaBuffer<f32>,
_bias: &CudaBuffer<f32>,
_running_mean: &mut CudaBuffer<f32>,
_running_var: &mut CudaBuffer<f32>,
_channels: usize,
_spatial: usize,
_eps: f32,
_momentum: f32,
_training: bool,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
let ctx = device.context();
// Compile (and cache) the kernel even though it is not launched yet.
let _f = crate::module_cache::get_or_compile(
ctx,
BATCHNORM_FORWARD_PTX,
"batchnorm_forward_kernel",
device.ordinal() as u32,
);
// Sentinel error: expected/got values carry no shape information.
Err(GpuError::ShapeMismatch {
op: "batchnorm_forward",
expected: vec![0],
got: vec![1],
})
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
/// CUDA-less build: BatchNorm forward is unavailable.
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize, _spatial: usize,
    _eps: f32, _momentum: f32, _training: bool,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
/// LayerNorm forward over a `(rows, cols)` view of `input`: each row is
/// normalized to zero mean / unit variance (epsilon-stabilized), then scaled
/// by `weight` and shifted by `bias`.
///
/// Launch shape: one block of 256 threads per row, 256*4 bytes of shared
/// memory. A PTX compile failure logs, dumps the PTX for debugging, and
/// falls back to an equivalent CPU loop.
pub fn gpu_layernorm(
input: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
bias: &CudaBuffer<f32>,
rows: usize,
cols: usize,
eps: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
LAYERNORM_PTX,
"layernorm_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
// Diagnostics: keep the failing PTX on disk for offline inspection.
eprintln!("ferrotorch-gpu: LayerNorm PTX compilation failed ({e:?}), CPU fallback");
std::fs::write("/tmp/layernorm_debug.ptx", LAYERNORM_PTX).ok();
eprintln!(
"ferrotorch-gpu: dumped PTX to /tmp/layernorm_debug.ptx ({} bytes)",
LAYERNORM_PTX.len()
);
let h_in = gpu_to_cpu(input, device)?;
let h_w = gpu_to_cpu(weight, device)?;
let h_b = gpu_to_cpu(bias, device)?;
let mut out = vec![0.0f32; rows * cols];
for r in 0..rows {
let base = r * cols;
let slice = &h_in[base..base + cols];
let mean: f32 = slice.iter().sum::<f32>() / cols as f32;
let var: f32 =
slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
let inv_std = 1.0 / (var + eps).sqrt();
for c in 0..cols {
let normed = (slice[c] - mean) * inv_std;
out[base + c] = h_w[c] * normed + h_b[c];
}
}
return cpu_to_gpu(&out, device);
}
};
let mut out = alloc_zeros_f32(rows * cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
// One block per row; 256 threads cooperate on that row's reduction.
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(weight.inner())
.arg(bias.inner())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
/// LayerNorm backward over a `(rows, cols)` view: returns
/// `(grad_input, grad_weight, grad_bias)`, with the weight/bias gradients
/// accumulated across rows.
///
/// Launch shape: one block of 256 threads per row, 256*4 bytes of shared
/// memory. Falls back to an equivalent CPU implementation when the PTX
/// cannot be compiled.
pub fn gpu_layernorm_backward(
input: &CudaBuffer<f32>,
grad_output: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
rows: usize,
cols: usize,
eps: f32,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
LAYERNORM_BACKWARD_PTX,
"layernorm_backward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
// CPU fallback: recompute per-row statistics, then apply the
// standard LayerNorm gradient
//   dx = inv_std * (dl - mean(dl) - x_hat * mean(dl * x_hat)).
let h_in = gpu_to_cpu(input, device)?;
let h_go = gpu_to_cpu(grad_output, device)?;
let h_w = gpu_to_cpu(weight, device)?;
let mut grad_input = vec![0.0f32; rows * cols];
let mut grad_weight = vec![0.0f32; cols];
let mut grad_bias = vec![0.0f32; cols];
let n_f = cols as f32;
for r in 0..rows {
let base = r * cols;
let x_slice = &h_in[base..base + cols];
let go_slice = &h_go[base..base + cols];
let mean: f32 = x_slice.iter().sum::<f32>() / n_f;
let var: f32 = x_slice
.iter()
.map(|&x| (x - mean) * (x - mean))
.sum::<f32>()
/ n_f;
let inv_std = 1.0 / (var + eps).sqrt();
let mut sum1 = 0.0f32;
let mut sum2 = 0.0f32;
for c in 0..cols {
let x_hat = (x_slice[c] - mean) * inv_std;
let dl = go_slice[c] * h_w[c];
sum1 += dl;
sum2 += dl * x_hat;
grad_weight[c] += go_slice[c] * x_hat;
grad_bias[c] += go_slice[c];
}
let m1 = sum1 / n_f;
let m2 = sum2 / n_f;
for c in 0..cols {
let x_hat = (x_slice[c] - mean) * inv_std;
let dl = go_slice[c] * h_w[c];
grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
}
}
let gi = cpu_to_gpu(&grad_input, device)?;
let gw = cpu_to_gpu(&grad_weight, device)?;
let gb = cpu_to_gpu(&grad_bias, device)?;
return Ok((gi, gw, gb));
}
};
let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
let mut grad_w = alloc_zeros_f32(cols, device)?;
let mut grad_b = alloc_zeros_f32(cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
// One block per row, mirroring the forward launch configuration.
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(grad_output.inner())
.arg(weight.inner())
.arg(grad_in.inner_mut())
.arg(grad_w.inner_mut())
.arg(grad_b.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok((grad_in, grad_w, grad_b))
}
/// LayerNorm backward stub for builds without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm_backward(
    _input: &CudaBuffer<f32>,
    _grad_output: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    // No GPU backend compiled in: always fail fast.
    Err(GpuError::NoCudaFeature)
}
/// RMSNorm forward over a `rows x cols` row-major input:
/// `out[r, c] = x[r, c] / sqrt(mean(x[r]^2) + eps) * weight[c]`.
///
/// Falls back to a CPU implementation when the PTX kernel fails to compile.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        RMSNORM_PTX,
        "rmsnorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            eprintln!("ferrotorch-gpu: RMSNorm PTX compilation failed ({e:?}), CPU fallback");
            // NOTE(review): dumping the PTX to /tmp looks like a debugging
            // aid left in place -- consider gating it behind a debug flag.
            std::fs::write("/tmp/rmsnorm_debug.ptx", RMSNORM_PTX).ok();
            eprintln!(
                "ferrotorch-gpu: dumped PTX to /tmp/rmsnorm_debug.ptx ({} bytes)",
                RMSNORM_PTX.len()
            );
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut out = vec![0.0f32; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                // Mean of squares over the row, then the reciprocal RMS.
                let sq_mean: f32 =
                    slice.iter().map(|&x| x * x).sum::<f32>() / cols as f32;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                for c in 0..cols {
                    out[base + c] = slice[c] * inv_rms * h_w[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };
    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row, 256 threads, 256 * 4 bytes of dynamic shared memory.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(out)
}
/// RMSNorm backward over a `rows x cols` row-major input.
///
/// Returns `(grad_input, grad_weight)`. Falls back to a CPU implementation
/// when the PTX kernel fails to compile.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;
    validate_unary(input, device)?;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        RMSNORM_BACKWARD_PTX,
        "rmsnorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback, recomputing the forward RMS per row.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; rows * cols];
            let mut grad_weight = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let sq_mean: f32 =
                    x_slice.iter().map(|&x| x * x).sum::<f32>() / n_f;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                let inv_rms3 = inv_rms * inv_rms * inv_rms;
                // dot = sum_c go * x * w, used by the correction term below.
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += go_slice[c] * x_slice[c] * h_w[c];
                    // grad_weight accumulates go * normalized(x) across rows.
                    grad_weight[c] += go_slice[c] * x_slice[c] * inv_rms;
                }
                let coeff = dot * inv_rms3 / n_f;
                for c in 0..cols {
                    grad_input[base + c] =
                        inv_rms * h_w[c] * go_slice[c] - x_slice[c] * coeff;
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            return Ok((gi, gw));
        }
    };
    let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row, 256 threads, 256 * 4 bytes of dynamic shared memory.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok((grad_in, grad_w))
}
/// RMSNorm forward stub for builds without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// RMSNorm backward stub for builds without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm_backward(
    _input: &CudaBuffer<f32>,
    _grad_output: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise addition `out = a + b` on the GPU, writing into an existing
/// destination buffer.
///
/// # Errors
/// * [`GpuError::ShapeMismatch`] when `out` is shorter than `a`.
/// * [`GpuError::PtxCompileFailed`] when `add_kernel` cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_add_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;
    // The destination must hold at least one element per input element.
    let needed = a.len();
    if out.len() < needed {
        return Err(GpuError::ShapeMismatch {
            op: "add_into",
            expected: vec![needed],
            got: vec![out.len()],
        });
    }
    // `try_launch_binary_into` reports whether the kernel actually launched.
    if try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")? {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "add_kernel",
        })
    }
}
/// Element-wise multiplication `out = a * b` on the GPU, writing into an
/// existing destination buffer.
///
/// # Errors
/// * [`GpuError::ShapeMismatch`] when `out` is shorter than `a`.
/// * [`GpuError::PtxCompileFailed`] when `mul_kernel` cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_mul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;
    // Refuse a destination that cannot hold every element of `a`.
    if out.len() < a.len() {
        return Err(GpuError::ShapeMismatch {
            op: "mul_into",
            expected: vec![a.len()],
            got: vec![out.len()],
        });
    }
    // A `false` return means the kernel could not be compiled.
    if !try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")? {
        return Err(GpuError::PtxCompileFailed {
            kernel: "mul_kernel",
        });
    }
    Ok(())
}
/// Scales `a` by `scalar` on the GPU, writing into an existing buffer:
/// `out[i] = a[i] * scalar`.
///
/// # Errors
/// * [`GpuError::ShapeMismatch`] when `out` is shorter than `a`.
/// * [`GpuError::PtxCompileFailed`] when `scale_kernel` cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_scale_into(
    a: &CudaBuffer<f32>,
    scalar: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    validate_unary(a, device)?;
    let n = a.len();
    // Fix: consistent with `gpu_add_into`/`gpu_mul_into`, refuse a
    // destination that cannot hold all `n` elements -- otherwise the kernel
    // would write past the end of `out` on the device.
    if out.len() < n {
        return Err(GpuError::ShapeMismatch {
            op: "scale_into",
            expected: vec![n],
            got: vec![out.len()],
        });
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "scale_kernel",
    })?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(())
}
/// Returns `true` if any element of `a` is non-finite (`NaN`, `+inf`, `-inf`).
///
/// Implemented by copying the buffer to the host and scanning there; an empty
/// buffer is trivially finite.
#[cfg(feature = "cuda")]
pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
    if a.len() == 0 {
        return Ok(false);
    }
    validate_unary(a, device)?;
    let host: Vec<f32> = crate::transfer::gpu_to_cpu(a, device)?;
    // "Non-finite anywhere" == "not finite everywhere".
    Ok(!host.iter().all(|v| v.is_finite()))
}
/// Non-finite scan stub for builds without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
    Err(GpuError::NoCudaFeature)
}
/// Applies GELU element-wise on the GPU, writing into an existing buffer.
///
/// # Errors
/// * [`GpuError::PtxCompileFailed`] when `gelu_kernel` cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_unary(a, device)?;
    // The helper returns `false` when the kernel could not be compiled.
    let launched = try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")?;
    if launched {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "gelu_kernel",
        })
    }
}
/// Copies one embedding row of width `d`, selected by `idx`, into `out`.
///
/// `idx` is a device buffer holding the row id (stored as `f32` -- presumably
/// matching the batch variant; confirm against the kernel).
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_into(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let stream = device.stream();
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "embed_lookup_kernel",
    })?;
    // Launch sized by the embedding width.
    let cfg = launch_cfg(d)?;
    let width = d as u32;
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&width)
            .launch(cfg)?;
    }
    Ok(())
}
/// Gathers `n` rows of width `d` from `weight` using `indices` (row ids
/// stored as `f32`), producing an `n * d` output buffer.
///
/// Falls back to a host-side gather when the PTX kernel cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_batch(
    indices: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    n: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    let total = n * d;
    if total == 0 {
        // Nothing to gather: return an empty buffer.
        return alloc_zeros_f32(0, device);
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_BATCH_PTX,
        "embed_lookup_batch_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: gather row-by-row on the host.
            let idx_host = gpu_to_cpu(indices, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let mut out = Vec::with_capacity(total);
            for &idx_f in &idx_host {
                // NOTE(review): row ids are not bounds-checked; an
                // out-of-range id panics on the slice below.
                let row = idx_f as usize;
                let start = row * d;
                out.extend_from_slice(&weight_host[start..start + d]);
            }
            return cpu_to_gpu(&out, device);
        }
    };
    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(indices.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Embedding backward: scatter-adds each `d`-wide row of `grad_output` into
/// the row of a `num_embeddings x d` gradient table selected by `indices`
/// (row ids stored as `f32`).
///
/// Falls back to a host-side scatter when the PTX kernel cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_rows(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    num_embeddings: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;
    let n = indices.len();
    let total = n * d;
    if total == 0 {
        // No lookups occurred: the gradient table is all zeros.
        return alloc_zeros_f32(num_embeddings * d, device);
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_ROWS_PTX,
        "scatter_add_rows_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; num_embeddings * d];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                // NOTE(review): `row` is not checked against
                // `num_embeddings`; an out-of-range id panics below.
                let row = idx_f as usize;
                for j in 0..d {
                    result[row * d + j] += go_host[i * d + j];
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Transposes an `m x n` row-major matrix `a` into `out` on the GPU.
///
/// # Errors
/// * [`GpuError::ShapeMismatch`] when `out` cannot hold `m * n` elements.
/// * [`GpuError::PtxCompileFailed`] when `transpose_2d_kernel` cannot compile.
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d_into(
    a: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    // Fix: consistent with the other `*_into` ops -- validate the input
    // buffer against the device and refuse a destination that is too small,
    // which would otherwise let the kernel write out of bounds.
    validate_unary(a, device)?;
    let total = m * n;
    if out.len() < total {
        return Err(GpuError::ShapeMismatch {
            op: "transpose_2d_into",
            expected: vec![total],
            got: vec![out.len()],
        });
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "transpose_2d_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(())
}
/// Permutes a `(d0, d1, d2, d3)` tensor to `(d0, d2, d1, d3)` layout on the
/// GPU, writing into an existing buffer.
///
/// # Errors
/// * [`GpuError::PtxCompileFailed`] when `permute_0213_kernel` cannot compile.
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213_into(
    a: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let element_count = d0 * d1 * d2 * d3;
    let stream = device.stream();
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "permute_0213_kernel",
    })?;
    let cfg = launch_cfg(element_count)?;
    // Kernel takes the four dimensions plus the flat element count, as u32.
    let dim0 = d0 as u32;
    let dim1 = d1 as u32;
    let dim2 = d2 as u32;
    let dim3 = d3 as u32;
    let count = element_count as u32;
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&dim0)
            .arg(&dim1)
            .arg(&dim2)
            .arg(&dim3)
            .arg(&count)
            .launch(cfg)?;
    }
    Ok(())
}
/// Row-wise softmax of a `rows x cols` matrix into `out`.
///
/// One block per row with 256 threads; dynamic shared memory holds one `f32`
/// per column (`cols * 4` bytes).
/// NOTE(review): for large `cols` this can exceed the device's dynamic
/// shared-memory limit -- confirm the expected column range.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_into(
    a: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "softmax_kernel",
    })?;
    let block_size = 256u32;
    let grid_size = rows as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: (cols as u32) * 4,
    };
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(())
}
/// LayerNorm forward over a `rows x cols` input, writing into an existing
/// buffer (no allocation, no CPU fallback -- compile failure is an error).
///
/// One block per row with 256 threads; dynamic shared memory is `cols * 4`
/// bytes, one `f32` per column.
/// NOTE(review): the allocating `gpu_layernorm*` path uses a fixed
/// `256 * 4` shared-memory size instead -- confirm which the kernel expects.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_layernorm_into(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "layernorm_kernel",
    })?;
    let block_size = 256u32;
    let grid_size = rows as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: (cols as u32) * 4,
    };
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(())
}
/// Reads a dense `(n_batch, len, d)` view out of a padded
/// `(n_batch, max_len, d)` source buffer into `out`.
/// NOTE(review): semantics inferred from the argument names and launch size
/// (`n_batch * len * d` threads) -- confirm against `slice_read_kernel`.
#[cfg(feature = "cuda")]
pub fn gpu_slice_read_into(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    // One launch slot per output element.
    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "slice_read_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }
    Ok(())
}
/// Matrix multiply `out = a (m x k) * b (k x n)` on the GPU for small
/// problem sizes, writing into an existing buffer.
///
/// The launch is sized by the `m * n` output elements.
///
/// # Errors
/// * [`GpuError::PtxCompileFailed`] when `small_matmul_kernel` cannot compile.
#[cfg(feature = "cuda")]
pub fn gpu_small_matmul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let stream = device.stream();
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "small_matmul_kernel",
    })?;
    let element_count = m * n;
    let cfg = launch_cfg(element_count)?;
    let m_u32 = m as u32;
    let k_u32 = k as u32;
    let n_u32 = n as u32;
    let total_u32 = element_count as u32;
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&k_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(())
}
/// Writes an `(n_batch, d)` slab from `src` into a padded
/// `(n_batch, max_len, d)` destination at a position read on-device from
/// `pos_ptr` (avoids a host round-trip for the write offset).
/// NOTE(review): semantics inferred from names and launch size -- confirm
/// against `slice_write_indirect_kernel`.
#[cfg(feature = "cuda")]
pub fn gpu_slice_write_indirect(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos_ptr: &cudarc::driver::CudaSlice<u32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let total = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SLICE_WRITE_INDIRECT_PTX,
        "slice_write_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "slice_write_indirect_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    // Despite the name, `n_u32` carries the flat element count
    // `n_batch * d`, not `n_batch`.
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(pos_ptr)
            .launch(cfg)?;
    }
    Ok(())
}
/// Fills an `(n_head, max_pos)` causal-mask buffer, reading the effective
/// sequence length on-device from `total_len_ptr` (no host round-trip).
/// NOTE(review): mask values/layout inferred from names -- confirm against
/// `causal_mask_indirect_kernel`.
#[cfg(feature = "cuda")]
pub fn gpu_causal_mask_indirect(
    total_len_ptr: &cudarc::driver::CudaSlice<u32>,
    n_head: usize,
    max_pos: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    // One launch slot per mask element.
    let total = n_head * max_pos;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        CAUSAL_MASK_INDIRECT_PTX,
        "causal_mask_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "causal_mask_indirect_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    let max_pos_u32 = max_pos as u32;
    let total_u32 = total as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(total_len_ptr)
            .arg(out.inner_mut())
            .arg(&max_pos_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(())
}
/// Eagerly compiles every kernel used on the decode path so the first call
/// does not pay the JIT cost.
///
/// Compilation stops at the first failure, which is reported as
/// [`GpuError::Driver`].
#[cfg(feature = "cuda")]
pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
    let ctx = device.context();
    ctx.bind_to_thread()?;
    let ord = device.ordinal() as u32;
    // (PTX source, kernel entry point) for each decode-path kernel, compiled
    // in the same order the original sequential calls used.
    let kernels: [(&'static str, &'static str); 16] = [
        (ADD_PTX, "add_kernel"),
        (MUL_PTX, "mul_kernel"),
        (SCALE_PTX, "scale_kernel"),
        (GELU_PTX, "gelu_kernel"),
        (SOFTMAX_PTX, "softmax_kernel"),
        (LAYERNORM_PTX, "layernorm_kernel"),
        (PERMUTE_0213_PTX, "permute_0213_kernel"),
        (EMBED_LOOKUP_PTX, "embed_lookup_kernel"),
        (EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel"),
        (SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel"),
        (SMALL_MATMUL_PTX, "small_matmul_kernel"),
        (SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel"),
        (CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel"),
        (SLICE_READ_PTX, "slice_read_kernel"),
        (RELU_BACKWARD_PTX, "relu_backward_kernel"),
        (GELU_BACKWARD_PTX, "gelu_backward_kernel"),
    ];
    for (ptx, name) in kernels {
        crate::module_cache::get_or_compile(ctx, ptx, name, ord)
            .map(|_| ())
            .map_err(GpuError::Driver)?;
    }
    Ok(())
}
/// Kernel precompilation stub for builds without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
// ---- CPU-only stubs: compiled when the `cuda` feature is disabled. Each
// keeps the GPU entry point's signature but fails fast with
// `GpuError::NoCudaFeature`. ----
/// GELU forward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Tanh-approximation GELU forward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_tanh(
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Erf-based GELU forward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_erf(
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Tanh-approximation GELU backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_tanh(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// SiLU forward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_silu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// SiLU backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_silu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// ELU forward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_elu(
    _input: &CudaBuffer<f32>,
    _alpha: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// ELU backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_elu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _alpha: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Mish forward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_mish(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Mish backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_mish_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise clamp (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_clamp(
    _input: &CudaBuffer<f32>,
    _min_val: f32,
    _max_val: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise division (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_div(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise exp (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise natural log (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise square root (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise power (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_pow(
    _a: &CudaBuffer<f32>,
    _exponent: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise absolute value (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise sigmoid (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise tanh (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// LayerNorm forward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// 2-D transpose (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_transpose_2d(
    _input: &CudaBuffer<f32>,
    _m: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise addition (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_add(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise subtraction (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_sub(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise multiplication (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Element-wise negation (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// ReLU forward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Scalar multiply (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_scale(
    _a: &CudaBuffer<f32>,
    _scalar: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Broadcasting addition (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_add(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Broadcasting subtraction (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_sub(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Broadcasting multiplication (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Row-wise softmax (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax(
    _input: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Dropout (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_dropout(
    _input: &CudaBuffer<f32>,
    _threshold: u32,
    _scale: f32,
    _seed: u32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// (0,2,1,3) permutation (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_permute_0213(
    _input: &CudaBuffer<f32>,
    _d0: usize,
    _d1: usize,
    _d2: usize,
    _d3: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Slice write into a padded buffer (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_write(
    _src: &CudaBuffer<f32>,
    _dst: &mut CudaBuffer<f32>,
    _n_batch: usize,
    _d: usize,
    _max_len: usize,
    _pos: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
/// Slice read from a padded buffer (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_read(
    _src: &CudaBuffer<f32>,
    _n_batch: usize,
    _d: usize,
    _len: usize,
    _max_len: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Single embedding-row lookup (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup(
    _idx: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Batched embedding lookup (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_batch(
    _indices: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _n: usize,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Row scatter-add for embedding backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_rows(
    _grad_output: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _num_embeddings: usize,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// ReLU backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// GELU backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// 1-D index select (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_index_select_1d(
    _input: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// 1-D scatter-add (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_1d(
    _grad_output: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _input_len: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Masked fill (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_fill(
    _input: &CudaBuffer<f32>,
    _mask: &CudaBuffer<f32>,
    _value: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Masked zeroing of gradients (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_zero(
    _grad: &CudaBuffer<f32>,
    _mask: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Sigmoid backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Tanh backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Softmax backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Log-softmax forward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax(
    _input: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Log-softmax backward (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Axis-wise sum reduction (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_sum_axis(
    _a: &CudaBuffer<f32>,
    _outer: usize,
    _axis_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Cumulative sum along a dimension (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumsum(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Cumulative product along a dimension (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumprod(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Cumulative max with indices (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummax(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
/// Cumulative min with indices (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummin(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
/// Log-cumsum-exp along a dimension (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_logcumsumexp(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Strided split along an axis (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_split(
    _input: &CudaBuffer<f32>,
    _total_along_axis: usize,
    _split_offset: usize,
    _split_size: usize,
    _inner_size: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Strided concatenation along an axis (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_cat(
    _input: &CudaBuffer<f32>,
    _output: &mut CudaBuffer<f32>,
    _total_along_axis: usize,
    _cat_offset: usize,
    _part_size: usize,
    _inner_size: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
/// Maximum rank supported by strided copies.
/// NOTE(review): presumably mirrors the CUDA build's constant -- confirm.
#[cfg(not(feature = "cuda"))]
pub const STRIDED_COPY_MAX_DIMS: usize = 8;
/// Strided `f32` copy (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_copy(
    _input: &CudaBuffer<f32>,
    _out_shape: &[usize],
    _src_strides: &[isize],
    _src_offset: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Strided `f64` copy (non-CUDA stub).
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_copy_f64(
    _input: &CudaBuffer<f64>,
    _out_shape: &[usize],
    _src_strides: &[isize],
    _src_offset: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    Err(GpuError::NoCudaFeature)
}
/// Converts an `f32` buffer to IEEE-754 half precision on the GPU, returning
/// the raw 16-bit patterns as a `CudaSlice<u16>`.
///
/// # Errors
/// * [`GpuError::PtxCompileFailed`] when `f32_to_f16_kernel` cannot compile.
#[cfg(feature = "cuda")]
pub(crate) fn gpu_f32_to_f16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;
    let stream = device.stream();
    let len = input.len();
    // Empty input: nothing to convert, hand back an empty allocation.
    if len == 0 {
        return Ok(stream.alloc_zeros::<u16>(0)?);
    }
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_F16_PTX,
        "f32_to_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_f16_kernel",
    })?;
    let mut half_bits = stream.alloc_zeros::<u16>(len)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(&mut half_bits)
            .arg(&len_u32)
            .launch(cfg)?;
    }
    Ok(half_bits)
}
/// f32 -> f16 conversion stub for builds without the `cuda` feature.
/// NOTE(review): the CUDA version returns `CudaSlice<u16>`; this stub returns
/// `()` -- presumably because `cudarc` types are unavailable without the
/// feature. Confirm no shared caller depends on the return value.
#[cfg(not(feature = "cuda"))]
pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
/// Converts an `f32` buffer to bfloat16 on the GPU, returning the raw 16-bit
/// patterns as a `CudaSlice<u16>`.
///
/// # Errors
/// * [`GpuError::PtxCompileFailed`] when `f32_to_bf16_kernel` cannot compile.
#[cfg(feature = "cuda")]
pub(crate) fn gpu_f32_to_bf16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;
    let element_count = input.len();
    let stream = device.stream();
    // Short-circuit the empty buffer: no kernel launch needed.
    if element_count == 0 {
        let empty = stream.alloc_zeros::<u16>(0)?;
        return Ok(empty);
    }
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_BF16_PTX,
        "f32_to_bf16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_bf16_kernel",
    })?;
    let mut bf16_bits = stream.alloc_zeros::<u16>(element_count)?;
    let cfg = launch_cfg(element_count)?;
    let count_u32 = element_count as u32;
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(&mut bf16_bits)
            .arg(&count_u32)
            .launch(cfg)?;
    }
    Ok(bf16_bits)
}
/// CPU-only stub: f32→bf16 conversion is unavailable without the `cuda` feature.
// NOTE(review): return type `GpuResult<()>` diverges from the CUDA variant's
// `GpuResult<CudaSlice<u16>>`; confirm callers only see one shape per build.
#[cfg(not(feature = "cuda"))]
pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
// ---------------------------------------------------------------------------
// f64 CPU-only stubs: when the crate is built without the `cuda` feature,
// every f64 GPU op below fails fast with `GpuError::NoCudaFeature`.
// NOTE(review): the index/embedding stubs take indices as `CudaBuffer<f32>`;
// presumably indices are stored as floats project-wide — confirm against the
// CUDA-side implementations.
// ---------------------------------------------------------------------------
#[cfg(not(feature = "cuda"))]
pub fn gpu_add_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_sub_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_mul_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_div_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_neg_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_scale_f64(_a: &CudaBuffer<f64>, _scalar: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_exp_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_sqrt_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_pow_f64(_a: &CudaBuffer<f64>, _exponent: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_abs_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_add_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_sub_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_mul_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_div_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_transpose_2d_f64(_input: &CudaBuffer<f64>, _m: usize, _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_permute_0213_f64(_input: &CudaBuffer<f64>, _d0: usize, _d1: usize, _d2: usize, _d3: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_split_f64(_input: &CudaBuffer<f64>, _total_along_axis: usize, _split_offset: usize, _split_size: usize, _inner_size: usize, _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_cat_f64(_input: &CudaBuffer<f64>, _output: &mut CudaBuffer<f64>, _total_along_axis: usize, _cat_offset: usize, _part_size: usize, _inner_size: usize, _n: usize, _device: &GpuDevice) -> GpuResult<()> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_index_select_1d_f64(_input: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_1d_f64(_grad_output: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _input_len: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_fill_f64(_input: &CudaBuffer<f64>, _mask: &CudaBuffer<u8>, _value: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_zero_f64(_grad: &CudaBuffer<f64>, _mask: &CudaBuffer<u8>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_write_f64(_src: &CudaBuffer<f64>, _dst: &mut CudaBuffer<f64>, _n_batch: usize, _d: usize, _max_len: usize, _pos: usize, _device: &GpuDevice) -> GpuResult<()> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_read_f64(_src: &CudaBuffer<f64>, _n_batch: usize, _d: usize, _len: usize, _max_len: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_f64(_idx: &CudaBuffer<f32>, _weight: &CudaBuffer<f64>, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_batch_f64(_indices: &CudaBuffer<f32>, _weight: &CudaBuffer<f64>, _n: usize, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_rows_f64(_grad_output: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _num_embeddings: usize, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// f64 GELU using the sigmoid approximation: x * sigmoid(1.702 * x).
/// Tries the GPU kernel first, computes on the host if PTX is unavailable.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, GELU_F64_PTX, "gelu_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(input, device, |x| {
            let sigmoid = 1.0 / (1.0 + (-1.702 * x).exp());
            x * sigmoid
        }),
    }
}
/// f64 GELU, tanh approximation:
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
#[cfg(feature = "cuda")]
pub fn gpu_gelu_tanh_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, GELU_TANH_F64_PTX, "gelu_tanh_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(input, device, |x| {
            let k = (2.0_f64 / std::f64::consts::PI).sqrt();
            let inner = k * (x + 0.044715 * x * x * x);
            0.5 * x * (1.0 + inner.tanh())
        }),
    }
}
/// f64 GELU in the "exact" erf form: x * 0.5 * (1 + erf(x / sqrt(2))).
#[cfg(feature = "cuda")]
pub fn gpu_gelu_erf_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    if let Some(out) = try_launch_unary_f64(input, device, GELU_ERF_F64_PTX, "gelu_erf_f64_kernel")? {
        return Ok(out);
    }
    // CPU fallback approximates erf with the Abramowitz & Stegun 7.1.26
    // rational polynomial, applied to |z| and mirrored for negative
    // arguments since erf is odd.
    cpu_fallback_unary_f64(input, device, |x| {
        let z = x * std::f64::consts::FRAC_1_SQRT_2;
        let az = z.abs();
        let t = 1.0 / (1.0 + 0.3275911 * az);
        let poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
        let erf_abs = 1.0 - poly * (-az * az).exp();
        let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
        x * 0.5 * (1.0 + erf_val)
    })
}
/// Backward of the sigmoid-approximated GELU:
/// d/dx [x * sigmoid(1.702 x)] = sig + 1.702 * x * sig * (1 - sig),
/// multiplied elementwise by the upstream gradient.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    if let Some(out) = try_launch_binary_f64(grad, input, device, GELU_BACKWARD_F64_PTX, "gelu_backward_f64_kernel")? {
        return Ok(out);
    }
    // Host fallback when the PTX kernel cannot be compiled.
    cpu_fallback_binary_f64(grad, input, device, |g, x| {
        let sig = 1.0 / (1.0 + (-1.702 * x).exp());
        g * (sig + 1.702 * x * sig * (1.0 - sig))
    })
}
/// Backward of the tanh-approximated GELU. With u = sqrt(2/pi)*(x + c*x^3),
/// c = 0.044715, t = tanh(u):
/// d/dx = 0.5*(1 + t) + 0.5*x*(1 - t^2)*sqrt(2/pi)*(1 + 3*c*x^2).
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_tanh_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    if let Some(out) = try_launch_binary_f64(grad, input, device, GELU_BACKWARD_TANH_F64_PTX, "gelu_backward_tanh_f64_kernel")? {
        return Ok(out);
    }
    // Host fallback when the PTX kernel cannot be compiled.
    cpu_fallback_binary_f64(grad, input, device, |g, x| {
        let s2pi = (2.0_f64 / std::f64::consts::PI).sqrt();
        let c = 0.044715_f64;
        let u = s2pi * (x + c * x * x * x);
        let t = u.tanh();
        let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * s2pi * (1.0 + 3.0 * c * x * x);
        g * d
    })
}
/// Backward of the erf-form GELU: d/dx [x * Phi(x)] = Phi(x) + x * phi(x),
/// where Phi is the standard normal CDF and phi its density.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_erf_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    if let Some(out) = try_launch_binary_f64(grad, input, device, GELU_BACKWARD_ERF_F64_PTX, "gelu_backward_erf_f64_kernel")? {
        return Ok(out);
    }
    // Host fallback: erf via the Abramowitz & Stegun 7.1.26 polynomial
    // (mirrored for negative arguments), pdf computed exactly.
    cpu_fallback_binary_f64(grad, input, device, |g, x| {
        let z = x * std::f64::consts::FRAC_1_SQRT_2;
        let az = z.abs();
        let t = 1.0 / (1.0 + 0.3275911 * az);
        let poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
        let erf_abs = 1.0 - poly * (-az * az).exp();
        let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
        let cdf = 0.5 * (1.0 + erf_val);
        let pdf = (-x * x / 2.0).exp() / (2.0 * std::f64::consts::PI).sqrt();
        g * (cdf + x * pdf)
    })
}
/// f64 SiLU (swish): x * sigmoid(x). GPU kernel with a host fallback.
#[cfg(feature = "cuda")]
pub fn gpu_silu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, SILU_F64_PTX, "silu_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(input, device, |x| {
            let denom = 1.0 + (-x).exp();
            x / denom
        }),
    }
}
/// Backward of SiLU: d/dx [x * sigmoid(x)] = sig + x * sig * (1 - sig),
/// multiplied elementwise by the upstream gradient.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_silu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    if let Some(out) = try_launch_binary_f64(grad, input, device, SILU_BACKWARD_F64_PTX, "silu_backward_f64_kernel")? {
        return Ok(out);
    }
    // Host fallback when the PTX kernel cannot be compiled.
    cpu_fallback_binary_f64(grad, input, device, |g, x| {
        let sig = 1.0 / (1.0 + (-x).exp());
        g * (sig + x * sig * (1.0 - sig))
    })
}
/// f64 ELU: x for x > 0, alpha * (e^x - 1) otherwise. Launches the dedicated
/// f64 kernel when it compiles; otherwise computes on the host and uploads.
#[cfg(feature = "cuda")]
pub fn gpu_elu_f64(
    input: &CudaBuffer<f64>,
    alpha: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    let n = input.len();
    // Empty input: round-trip an empty slice to produce an empty device buffer.
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ELU_F64_PTX, "elu_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let n_u32 = n as u32;
        let cfg = launch_cfg(n)?;
        // SAFETY: argument order must match the kernel's parameter list;
        // all buffers outlive the launch on this stream.
        unsafe {
            stream.launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&n_u32)
                .arg(&alpha)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    // Host fallback when PTX compilation is unavailable.
    let host = gpu_to_cpu(input, device)?;
    let result: Vec<f64> = host.iter().map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }).collect();
    cpu_to_gpu(&result, device)
}
/// Backward of ELU: gradient passes through unchanged for x > 0, and is
/// scaled by alpha * e^x otherwise.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_elu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    alpha: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    let n = grad.len();
    // Empty input: round-trip an empty slice to produce an empty device buffer.
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ELU_BACKWARD_F64_PTX, "elu_backward_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let n_u32 = n as u32;
        let cfg = launch_cfg(n)?;
        // SAFETY: argument order must match the kernel's parameter list;
        // all buffers outlive the launch on this stream.
        unsafe {
            stream.launch_builder(&f)
                .arg(grad.inner())
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&n_u32)
                .arg(&alpha)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    // Host fallback when PTX compilation is unavailable.
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let result: Vec<f64> = g_host.iter().zip(x_host.iter()).map(|(&g, &x)| if x > 0.0 { g } else { g * alpha * x.exp() }).collect();
    cpu_to_gpu(&result, device)
}
/// f64 Mish: x * tanh(softplus(x)), where softplus(x) = ln(1 + e^x).
#[cfg(feature = "cuda")]
pub fn gpu_mish_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, MISH_F64_PTX, "mish_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(input, device, |x| {
            let softplus = (1.0_f64 + x.exp()).ln();
            x * softplus.tanh()
        }),
    }
}
/// Backward of Mish. With sp = softplus(x), t = tanh(sp), sig = sigmoid(x):
/// d/dx [x * tanh(softplus(x))] = t + x * sig * (1 - t^2).
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_mish_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    if let Some(out) = try_launch_binary_f64(grad, input, device, MISH_BACKWARD_F64_PTX, "mish_backward_f64_kernel")? {
        return Ok(out);
    }
    // Host fallback when the PTX kernel cannot be compiled.
    cpu_fallback_binary_f64(grad, input, device, |g, x| {
        let sp = (1.0_f64 + x.exp()).ln();
        let t = sp.tanh();
        let sig = 1.0 / (1.0 + (-x).exp());
        g * (t + x * sig * (1.0 - t * t))
    })
}
/// Elementwise clamp of an f64 buffer into `[min_val, max_val]`.
/// The f64 PTX is derived from the f32 `clamp_kernel` and memoized in a
/// process-wide `OnceLock`; on compile failure the clamp runs on the host.
#[cfg(feature = "cuda")]
pub fn gpu_clamp_f64(
    input: &CudaBuffer<f64>,
    min_val: f64,
    max_val: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    // Memoizes the f32→f64 PTX rewrite for the process lifetime.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let n = input.len();
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CLAMP_PTX, "clamp_kernel", "clamp_f64_kernel");
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ptx, "clamp_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let n_u32 = n as u32;
        let cfg = launch_cfg(n)?;
        // SAFETY: argument order must match the kernel's parameter list;
        // all buffers outlive the launch on this stream.
        unsafe {
            stream.launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&n_u32)
                .arg(&min_val)
                .arg(&max_val)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    // Host fallback. NOTE(review): `f64::max`/`min` return the non-NaN
    // operand, so a NaN input maps to `min_val` here — confirm the GPU
    // kernel treats NaN the same way.
    let host = gpu_to_cpu(input, device)?;
    let result: Vec<f64> = host.iter().map(|&x| x.max(min_val).min(max_val)).collect();
    cpu_to_gpu(&result, device)
}
/// Cumulative sum along one axis of a tensor laid out as
/// `(outer, dim_size, inner)` in row-major order. One thread per
/// (outer, inner) pair scans its lane of `dim_size` elements serially.
/// No CPU fallback: PTX compile failure surfaces as `PtxCompileFailed`.
#[cfg(feature = "cuda")]
pub fn gpu_cumsum_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    // Memoizes the f32→f64 PTX rewrite for the process lifetime.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let total = outer * inner; // one scan lane per (outer, inner) pair
    let n = outer * dim_size * inner; // element count of input and output
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMSUM_PTX, "cumsum_kernel", "cumsum_f64_kernel");
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ptx, "cumsum_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let cfg = launch_cfg(total)?;
        let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
        // SAFETY: argument order must match the kernel's parameter list;
        // all buffers outlive the launch on this stream.
        unsafe {
            stream.launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&o)
                .arg(&d)
                .arg(&i)
                .arg(&t)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    Err(GpuError::PtxCompileFailed { kernel: "cumsum_f64_kernel" })
}
/// Cumulative product along one axis of a tensor laid out as
/// `(outer, dim_size, inner)`; same launch shape as [`gpu_cumsum_f64`].
/// No CPU fallback: PTX compile failure surfaces as `PtxCompileFailed`.
#[cfg(feature = "cuda")]
pub fn gpu_cumprod_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    // Memoizes the f32→f64 PTX rewrite for the process lifetime.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let total = outer * inner; // one scan lane per (outer, inner) pair
    let n = outer * dim_size * inner; // element count of input and output
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMPROD_PTX, "cumprod_kernel", "cumprod_f64_kernel");
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ptx, "cumprod_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let cfg = launch_cfg(total)?;
        let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
        // SAFETY: argument order must match the kernel's parameter list;
        // all buffers outlive the launch on this stream.
        unsafe {
            stream.launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&o)
                .arg(&d)
                .arg(&i)
                .arg(&t)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    Err(GpuError::PtxCompileFailed { kernel: "cumprod_f64_kernel" })
}
/// Cumulative maximum along one axis of a `(outer, dim_size, inner)` tensor.
/// Returns `(values, indices)`; indices are materialized in an f64 buffer
/// (presumably because the buffer type here carries floats only — confirm
/// against the kernel). No CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_cummax_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    let total = outer * inner; // one scan lane per (outer, inner) pair
    let n = outer * dim_size * inner; // element count of input and outputs
    if n == 0 {
        let e: &[f64] = &[];
        return Ok((cpu_to_gpu(e, device)?, cpu_to_gpu(e, device)?));
    }
    // Memoizes the f32→f64 PTX rewrite for the process lifetime.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMMAX_PTX, "cummax_kernel", "cummax_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cummax_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cummax_f64_kernel" })?;
    let mut out = alloc_zeros_f64(n, device)?;
    let mut ind = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: argument order must match the kernel's parameter list;
    // all buffers outlive the launch on this stream.
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(ind.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok((out, ind))
}
/// Cumulative minimum along one axis of a `(outer, dim_size, inner)` tensor.
/// Returns `(values, indices)`; indices are materialized in an f64 buffer,
/// mirroring [`gpu_cummax_f64`]. No CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_cummin_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    let total = outer * inner; // one scan lane per (outer, inner) pair
    let n = outer * dim_size * inner; // element count of input and outputs
    if n == 0 {
        let e: &[f64] = &[];
        return Ok((cpu_to_gpu(e, device)?, cpu_to_gpu(e, device)?));
    }
    // Memoizes the f32→f64 PTX rewrite for the process lifetime.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMMIN_PTX, "cummin_kernel", "cummin_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cummin_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cummin_f64_kernel" })?;
    let mut out = alloc_zeros_f64(n, device)?;
    let mut ind = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: argument order must match the kernel's parameter list;
    // all buffers outlive the launch on this stream.
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(ind.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok((out, ind))
}
/// Numerically-stable log-cumsum-exp along one axis of a
/// `(outer, dim_size, inner)` tensor, via a dedicated f64 kernel (no f32→f64
/// PTX rewrite here). No CPU fallback: compile failure surfaces as
/// `PtxCompileFailed`.
#[cfg(feature = "cuda")]
pub fn gpu_logcumsumexp_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    let total = outer * inner; // one scan lane per (outer, inner) pair
    let n = outer * dim_size * inner; // element count of input and output
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, LOGCUMSUMEXP_F64_PTX, "logcumsumexp_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let cfg = launch_cfg(total)?;
        let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
        // SAFETY: argument order must match the kernel's parameter list;
        // all buffers outlive the launch on this stream.
        unsafe {
            stream.launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&o)
                .arg(&d)
                .arg(&i)
                .arg(&t)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    Err(GpuError::PtxCompileFailed { kernel: "logcumsumexp_f64_kernel" })
}
/// Row-wise softmax of a `rows x cols` row-major f64 matrix. Uses the
/// dedicated f64 kernel when it compiles; otherwise computes a max-shifted
/// softmax on the host and uploads the result.
// NOTE(review): `input.len()` is never checked against `rows * cols` —
// a mismatch would read/write out of range on either path; confirm callers
// always pass consistent dimensions.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_f64(
    input: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    validate_device(input, device)?;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_F64_PTX,
        "softmax_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: subtract the row max before exponentiating for
            // numerical stability, then normalize by the row sum.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; host.len()];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f64::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum = 0.0f64;
                for c in 0..cols {
                    let e = (host[base + c] - max_v).exp();
                    out[base + c] = e;
                    sum += e;
                }
                let inv = 1.0 / sum;
                for c in 0..cols {
                    out[base + c] *= inv;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };
    let mut out = alloc_zeros_f64(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row; 256 threads with 256 * 8 bytes of shared memory
    // (one f64 per thread) for the row reduction.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };
    // SAFETY: argument order must match the kernel's parameter list;
    // all buffers outlive the launch on this stream.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Backward pass of row-wise softmax: given the upstream gradient `grad` and
/// the forward output `output` (both `rows x cols`, row-major), computes
/// `out[r,c] = output[r,c] * (grad[r,c] - dot(grad[r,:], output[r,:]))`.
/// Tries the f64 PTX kernel (derived from the f32 kernel and memoized);
/// on compile failure computes the same result on the host.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `output` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    validate_device(grad, device)?;
    if grad.len() != output.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: output.len() });
    }
    let total = grad.len();
    // Guard the `total / cols` below: `cols == 0` would panic with a
    // division by zero, and an empty tensor needs no kernel launch at all.
    if total == 0 || cols == 0 {
        return cpu_to_gpu(&[], device);
    }
    let rows = total / cols;
    let ctx = device.context();
    let stream = device.stream();
    // Memoizes the f32→f64 PTX rewrite for the process lifetime.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&CACHE, SOFTMAX_BACKWARD_PTX, "softmax_backward_kernel", "softmax_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "softmax_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: identical math, row by row.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut dot = 0.0f64;
                for c in 0..cols {
                    dot += grad_host[base + c] * output_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row; 256 threads with 256 * 8 bytes of shared memory
    // (one f64 per thread) for the row reduction.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };
    // SAFETY: argument order must match the kernel's parameter list;
    // all buffers outlive the launch on this stream.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Row-wise log-softmax of a row-major f64 matrix with `cols` columns
/// (`rows` inferred as `input.len() / cols`). Uses the dedicated f64 kernel
/// when it compiles; otherwise computes a max-shifted log-softmax on the
/// host: `out = x - (max + ln(sum(exp(x - max))))`.
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_f64(
    input: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    validate_device(input, device)?;
    let total = input.len();
    // Guard the `total / cols` below: `cols == 0` would panic with a
    // division by zero, and an empty tensor needs no kernel launch at all.
    if total == 0 || cols == 0 {
        return cpu_to_gpu(&[], device);
    }
    let rows = total / cols;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_F64_PTX,
        "log_softmax_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback with the standard max-shift for stability.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f64::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum_exp = 0.0f64;
                for c in 0..cols {
                    sum_exp += (host[base + c] - max_v).exp();
                }
                let log_sum_exp = max_v + sum_exp.ln();
                for c in 0..cols {
                    out[base + c] = host[base + c] - log_sum_exp;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };
    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row; 256 threads with 256 * 8 bytes of shared memory
    // (one f64 per thread) for the row reduction.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };
    // SAFETY: argument order must match the kernel's parameter list;
    // all buffers outlive the launch on this stream.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Backward pass of row-wise log-softmax: given upstream gradient `grad` and
/// the forward output `output` (log-probabilities), computes
/// `out[r,c] = grad[r,c] - exp(output[r,c]) * sum(grad[r,:])`.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `output` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    validate_device(grad, device)?;
    if grad.len() != output.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: output.len() });
    }
    let total = grad.len();
    // Guard the `total / cols` below: `cols == 0` would panic with a
    // division by zero, and an empty tensor needs no kernel launch at all.
    if total == 0 || cols == 0 {
        return cpu_to_gpu(&[], device);
    }
    let rows = total / cols;
    let ctx = device.context();
    let stream = device.stream();
    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_BACKWARD_F64_PTX,
        "log_softmax_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: identical math, row by row.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut sum_grad = 0.0f64;
                for c in 0..cols {
                    sum_grad += grad_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] =
                        grad_host[base + c] - output_host[base + c].exp() * sum_grad;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };
    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row; 256 threads with 256 * 8 bytes of shared memory
    // (one f64 per thread) for the row reduction.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };
    // SAFETY: argument order must match the kernel's parameter list;
    // all buffers outlive the launch on this stream.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Row-wise layer normalization of a `rows x cols` f64 matrix:
/// `out[r,c] = weight[c] * (x[r,c] - mean_r) / sqrt(var_r + eps) + bias[c]`,
/// where mean/variance are taken over each row (biased variance, divide by
/// `cols`). Falls back to the host when the f64 PTX cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_f64(
    input: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    bias: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    // Memoizes the f32→f64 PTX rewrite for the process lifetime.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    validate_device(input, device)?;
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, LAYERNORM_PTX, "layernorm_kernel", "layernorm_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "layernorm_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: per-row mean/variance, then affine transform.
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f64; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let mean: f64 = slice.iter().sum::<f64>() / cols as f64;
                let var: f64 =
                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / cols as f64;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (slice[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };
    let mut out = alloc_zeros_f64(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row; 256 threads with 256 * 8 bytes of shared memory
    // (one f64 per thread) for the row reductions.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };
    // SAFETY: argument order must match the kernel's parameter list;
    // all buffers outlive the launch on this stream.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Backward pass of layer normalization. Returns
/// `(grad_input, grad_weight, grad_bias)` where, per row with
/// `x_hat = (x - mean) * inv_std` and `dl = grad_output * weight`:
/// `grad_input = inv_std * (dl - mean(dl) - x_hat * mean(dl * x_hat))`,
/// `grad_weight[c] = sum_r grad_output[r,c] * x_hat[r,c]`,
/// `grad_bias[c] = sum_r grad_output[r,c]`.
/// Falls back to the host when the f64 PTX cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_backward_f64(
    input: &CudaBuffer<f64>,
    grad_output: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    // Memoizes the f32→f64 PTX rewrite for the process lifetime.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    validate_device(input, device)?;
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, LAYERNORM_BACKWARD_PTX, "layernorm_backward_kernel", "layernorm_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "layernorm_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: recompute per-row statistics, accumulate the
            // weight/bias gradients across rows, and apply the standard
            // layernorm input-gradient formula.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f64; rows * cols];
            let mut grad_weight = vec![0.0f64; cols];
            let mut grad_bias = vec![0.0f64; cols];
            let n_f = cols as f64;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let mean: f64 = x_slice.iter().sum::<f64>() / n_f;
                let var: f64 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f64>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                // sum1 = sum(dl), sum2 = sum(dl * x_hat) over the row.
                let mut sum1 = 0.0f64;
                let mut sum2 = 0.0f64;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };
    let mut grad_in = alloc_zeros_f64(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f64(cols, device)?;
    let mut grad_b = alloc_zeros_f64(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row; 256 threads with 256 * 8 bytes of shared memory
    // (one f64 per thread) for the row reductions.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };
    // SAFETY: argument order must match the kernel's parameter list;
    // all buffers outlive the launch on this stream.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok((grad_in, grad_w, grad_b))
}
/// Row-wise RMS normalization of a `rows x cols` f64 matrix:
/// `out[r,c] = x[r,c] / sqrt(mean(x[r,:]^2) + eps) * weight[c]`.
/// Falls back to the host when the f64 PTX cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_f64(
    input: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    // Memoizes the f32→f64 PTX rewrite for the process lifetime.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    validate_device(input, device)?;
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, RMSNORM_PTX, "rmsnorm_kernel", "rmsnorm_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "rmsnorm_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: per-row mean of squares, then scale.
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut out = vec![0.0f64; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let sq_mean: f64 =
                    slice.iter().map(|&x| x * x).sum::<f64>() / cols as f64;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                for c in 0..cols {
                    out[base + c] = slice[c] * inv_rms * h_w[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };
    let mut out = alloc_zeros_f64(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row; 256 threads with 256 * 8 bytes of shared memory
    // (one f64 per thread) for the row reduction.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };
    // SAFETY: argument order must match the kernel's parameter list;
    // all buffers outlive the launch on this stream.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Backward pass of RMS normalization. Returns `(grad_input, grad_weight)`
/// where, per row with `inv_rms = 1/sqrt(mean(x^2) + eps)`:
/// `grad_input = inv_rms * w * go - x * (sum(go * x * w) * inv_rms^3 / cols)`,
/// `grad_weight[c] = sum_r go[r,c] * x[r,c] * inv_rms_r`.
/// Falls back to the host when the f64 PTX cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_backward_f64(
    input: &CudaBuffer<f64>,
    grad_output: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    // Memoizes the f32→f64 PTX rewrite for the process lifetime.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    validate_device(input, device)?;
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, RMSNORM_BACKWARD_PTX, "rmsnorm_backward_kernel", "rmsnorm_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "rmsnorm_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: recompute per-row inv_rms, accumulate the
            // weight gradient across rows, then the input gradient.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f64; rows * cols];
            let mut grad_weight = vec![0.0f64; cols];
            let n_f = cols as f64;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let sq_mean: f64 =
                    x_slice.iter().map(|&x| x * x).sum::<f64>() / n_f;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                let inv_rms3 = inv_rms * inv_rms * inv_rms;
                let mut dot = 0.0f64;
                for c in 0..cols {
                    dot += go_slice[c] * x_slice[c] * h_w[c];
                    grad_weight[c] += go_slice[c] * x_slice[c] * inv_rms;
                }
                let coeff = dot * inv_rms3 / n_f;
                for c in 0..cols {
                    grad_input[base + c] =
                        inv_rms * h_w[c] * go_slice[c] - x_slice[c] * coeff;
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            return Ok((gi, gw));
        }
    };
    let mut grad_in = alloc_zeros_f64(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f64(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // One block per row; 256 threads with 256 * 8 bytes of shared memory
    // (one f64 per thread) for the row reductions.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };
    // SAFETY: argument order must match the kernel's parameter list;
    // all buffers outlive the launch on this stream.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok((grad_in, grad_w))
}
// ---------------------------------------------------------------------------
// Stubs compiled when the `cuda` feature is disabled. Each mirrors the
// signature of its CUDA counterpart so downstream code compiles either way,
// and fails fast with the recoverable `GpuError::NoCudaFeature` error.
// ---------------------------------------------------------------------------
/// Non-CUDA stub of the f64 softmax launcher; always `Err(GpuError::NoCudaFeature)`.
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_f64(_input: &CudaBuffer<f64>, _rows: usize, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 softmax backward launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 log-softmax launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax_f64(_input: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 log-softmax backward launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 LayerNorm forward launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm_f64(_input: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _bias: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 LayerNorm backward launcher (grad_input, grad_weight, grad_bias).
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm_backward_f64(_input: &CudaBuffer<f64>, _grad_output: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 RMSNorm forward launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm_f64(_input: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 RMSNorm backward launcher (grad_input, grad_weight).
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm_backward_f64(_input: &CudaBuffer<f64>, _grad_output: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 GELU launcher (default variant).
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 GELU (tanh approximation) launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_tanh_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 GELU (erf form) launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_erf_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 GELU backward launcher (default variant).
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 GELU (tanh approximation) backward launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_tanh_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 GELU (erf form) backward launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_erf_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 SiLU launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_silu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 SiLU backward launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_silu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 ELU launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_elu_f64(_input: &CudaBuffer<f64>, _alpha: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 ELU backward launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_elu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _alpha: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 Mish launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_mish_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 Mish backward launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_mish_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 clamp launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_clamp_f64(_input: &CudaBuffer<f64>, _min: f64, _max: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 cumulative-sum launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumsum_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 cumulative-product launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumprod_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 cumulative-max launcher (values, indices).
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummax_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 cumulative-min launcher (values, indices).
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummin_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
/// Non-CUDA stub of the f64 log-cumsum-exp launcher.
#[cfg(not(feature = "cuda"))]
pub fn gpu_logcumsumexp_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;

    /// Opens CUDA device 0 and uploads `data`, returning both handles.
    fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
        let device = GpuDevice::new(0).expect("CUDA device 0");
        let buffer = cpu_to_gpu(data, &device).expect("cpu_to_gpu");
        (device, buffer)
    }

    /// Downloads `buf` and compares it element-wise to `expected`
    /// with a small absolute tolerance.
    fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
        let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
        assert_eq!(host.len(), expected.len(), "length mismatch");
        for i in 0..expected.len() {
            let got = host[i];
            let exp = expected[i];
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    #[test]
    fn add_basic() {
        let a_data = [1.0f32, 2.0, 3.0, 4.0];
        let b_data = [10.0f32, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = (0..a_data.len()).map(|i| a_data[i] + b_data[i]).collect();
        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn add_empty() {
        let (dev, a) = setup(&[]);
        let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add empty");
        assert_eq!(0, out.len());
    }

    #[test]
    fn add_large() {
        const N: usize = 100_000;
        let a_data: Vec<f32> = (0..N).map(|i| i as f32).collect();
        let b_data: Vec<f32> = a_data.iter().map(|&x| x * 0.5).collect();
        let expected: Vec<f32> = (0..N).map(|i| a_data[i] + b_data[i]).collect();
        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add large");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn add_length_mismatch() {
        let (dev, a) = setup(&[1.0, 2.0, 3.0]);
        let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
        let err = gpu_add(&a, &b, &dev).unwrap_err();
        match err {
            GpuError::LengthMismatch { a: 3, b: 2 } => (),
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn sub_basic() {
        let a_data = [10.0f32, 20.0, 30.0, 40.0];
        let b_data = [1.0f32, 2.0, 3.0, 4.0];
        let expected: Vec<f32> = (0..a_data.len()).map(|i| a_data[i] - b_data[i]).collect();
        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn sub_negative_result() {
        let (dev, a) = setup(&[1.0f32, 2.0]);
        let b = cpu_to_gpu(&[5.0f32, 10.0], &dev).expect("cpu_to_gpu b");
        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&out, &dev, &[-4.0, -8.0]);
    }

    #[test]
    fn mul_basic() {
        let a_data = [2.0f32, 3.0, 4.0, 5.0];
        let b_data = [10.0f32; 4];
        let expected: Vec<f32> = (0..a_data.len()).map(|i| a_data[i] * b_data[i]).collect();
        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn mul_by_zero() {
        let (dev, a) = setup(&[1.0f32, 2.0, 3.0]);
        let b = cpu_to_gpu(&[0.0f32; 3], &dev).expect("cpu_to_gpu b");
        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&out, &dev, &[0.0f32; 3]);
    }

    #[test]
    fn neg_basic() {
        let a_data = [1.0f32, -2.0, 3.0, 0.0, -5.5];
        let expected: Vec<f32> = a_data.iter().map(|&x| -x).collect();
        let (dev, a) = setup(&a_data);
        let out = gpu_neg(&a, &dev).expect("gpu_neg");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn neg_double_negation() {
        // Negating twice must reproduce the original values exactly.
        let a_data = [1.0f32, -2.0, 3.0];
        let (dev, a) = setup(&a_data);
        let once = gpu_neg(&a, &dev).expect("gpu_neg 1");
        let twice = gpu_neg(&once, &dev).expect("gpu_neg 2");
        assert_buf_eq(&twice, &dev, &a_data);
    }

    #[test]
    fn relu_basic() {
        let (dev, a) = setup(&[-3.0f32, -1.0, 0.0, 1.0, 3.0]);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &[0.0f32, 0.0, 0.0, 1.0, 3.0]);
    }

    #[test]
    fn relu_all_negative() {
        let (dev, a) = setup(&[-5.0f32, -0.1, -100.0]);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &[0.0f32; 3]);
    }

    #[test]
    fn relu_all_positive() {
        // Positive inputs pass through unchanged.
        let a_data = [0.1f32, 1.0, 100.0];
        let (dev, a) = setup(&a_data);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &a_data);
    }

    #[test]
    fn relu_empty() {
        let (dev, a) = setup(&[]);
        let out = gpu_relu(&a, &dev).expect("gpu_relu empty");
        assert_eq!(0, out.len());
    }

    #[test]
    fn small_matmul_2x2() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
        let b = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
        let c = gpu_small_matmul(&a, &b, 2, 2, 2, &dev).unwrap();
        assert_buf_eq(&c, &dev, &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn small_matmul_1xk_kxn() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
        let b = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
        let c = gpu_small_matmul(&a, &b, 1, 3, 2, &dev).unwrap();
        assert_buf_eq(&c, &dev, &[4.0, 5.0]);
    }

    #[test]
    fn small_matmul_vs_cublas() {
        // Cross-check the custom small-matmul kernel against cuBLAS.
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let (m, k, n) = (1, 64, 64);
        let a_data: Vec<f32> = (0..m * k)
            .map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
            .collect();
        let b_data: Vec<f32> = (0..k * n)
            .map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
            .collect();
        let a = cpu_to_gpu(&a_data, &dev).unwrap();
        let b = cpu_to_gpu(&b_data, &dev).unwrap();
        let c_cublas = crate::blas::gpu_matmul_f32(&a, &b, m, k, n, &dev).unwrap();
        let cublas_result = gpu_to_cpu(&c_cublas, &dev).unwrap();
        let c_ours = gpu_small_matmul(&a, &b, m, k, n, &dev).unwrap();
        let our_result = gpu_to_cpu(&c_ours, &dev).unwrap();
        assert_eq!(cublas_result.len(), our_result.len());
        for i in 0..cublas_result.len() {
            let cb = cublas_result[i];
            let ours = our_result[i];
            assert!(
                (cb - ours).abs() < 0.1,
                "element {i}: cuBLAS={cb}, ours={ours}, diff={}",
                (cb - ours).abs()
            );
        }
    }

    #[test]
    fn strided_copy_identity_contiguous_2d() {
        let data: Vec<f32> = (0..6).map(|i| i as f32).collect();
        let (dev, input) = setup(&data);
        let out = gpu_strided_copy(&input, &[2, 3], &[3, 1], 0, &dev)
            .expect("strided_copy identity");
        assert_buf_eq(&out, &dev, &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn strided_copy_transpose_2d() {
        let data: Vec<f32> = (0..6).map(|i| i as f32).collect();
        let (dev, input) = setup(&data);
        let out = gpu_strided_copy(&input, &[3, 2], &[1, 3], 0, &dev)
            .expect("strided_copy transpose");
        assert_buf_eq(&out, &dev, &[0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }

    #[test]
    fn strided_copy_sliced_column() {
        // Column slice: stride 4 with offset 2 picks column 2 of a 3x4 matrix.
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let (dev, input) = setup(&data);
        let out = gpu_strided_copy(&input, &[3], &[4], 2, &dev)
            .expect("strided_copy col slice");
        assert_buf_eq(&out, &dev, &[2.0, 6.0, 10.0]);
    }

    #[test]
    fn strided_copy_3d_permute() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let (dev, input) = setup(&data);
        let out =
            gpu_strided_copy(&input, &[2, 4, 3], &[12, 1, 4], 0, &dev).expect("strided_copy 3d");
        // Reference: per-batch transpose of a 3x4 block into 4x3.
        let mut expected = vec![0.0f32; 24];
        for b in 0..2 {
            for i in 0..4 {
                for j in 0..3 {
                    expected[b * 12 + i * 3 + j] = data[b * 12 + j * 4 + i];
                }
            }
        }
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn strided_copy_4d_max_rank_supported() {
        let shape = [2usize, 3, 2, 2];
        let data: Vec<f32> = (0..shape.iter().product::<usize>()).map(|i| i as f32).collect();
        let (dev, input) = setup(&data);
        let out = gpu_strided_copy(&input, &shape, &[12, 4, 2, 1], 0, &dev)
            .expect("strided_copy 4d");
        assert_buf_eq(&out, &dev, &data);
    }

    #[test]
    fn strided_copy_rejects_too_many_dims() {
        let (dev, input) = setup(&[0.0f32; 16]);
        let shape = [1, 1, 1, 1, 1, 1, 1, 1, 16];
        let result = gpu_strided_copy(&input, &shape, &[1; 9], 0, &dev);
        assert!(result.is_err());
    }

    #[test]
    fn strided_copy_rejects_shape_stride_length_mismatch() {
        let (dev, input) = setup(&[0.0f32; 12]);
        let result = gpu_strided_copy(&input, &[3, 4], &[4, 1, 1], 0, &dev);
        assert!(result.is_err());
    }

    #[test]
    fn strided_copy_rejects_negative_stride() {
        let (dev, input) = setup(&[0.0f32; 6]);
        let result = gpu_strided_copy(&input, &[2, 3], &[3, -1], 0, &dev);
        assert!(result.is_err());
    }

    #[test]
    fn strided_copy_empty_output() {
        let (dev, input) = setup(&[1.0f32, 2.0, 3.0]);
        let out = gpu_strided_copy(&input, &[0, 3], &[3, 1], 0, &dev)
            .expect("strided_copy empty");
        assert_eq!(0, out.len());
    }

    #[test]
    fn strided_copy_f64_transpose_matches_f32() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let data: Vec<f64> = (0..6).map(|i| i as f64).collect();
        let input = cpu_to_gpu(&data, &dev).expect("cpu_to_gpu f64");
        let out = gpu_strided_copy_f64(&input, &[3, 2], &[1, 3], 0, &dev)
            .expect("strided_copy_f64 transpose");
        let host = gpu_to_cpu(&out, &dev).expect("gpu_to_cpu f64");
        assert_eq!(host, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }
}