#![cfg(feature = "cuda")]
use cudarc::driver::{LaunchConfig, PushKernelArg};
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::module_cache::get_or_compile;
/// Threads per block used by `launch_1d` and the row-per-block launches in this module.
///
/// The row-reduction kernels declare 256-entry `.shared` f32 scratch arrays and halve the
/// block size in their tree reductions. This value must therefore stay a power of two no
/// larger than 256.
const BLOCK_SIZE: u32 = 1 << 8;
/// PTX source for `mul_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It loads the f16 bit patterns `a[i]` and `b[i]` and widens both to f32. It then multiplies
/// them and stores `out[i]` rounded back to f16 (`cvt.rn`, round-to-nearest-even).
/// Parameters: `(a_ptr, b_ptr, out_ptr, n: u32)`.
const MUL_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry mul_f16_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .b16 %a_b16, %b_b16, %out_h;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
ld.global.b16 %b_b16, [%b];
cvt.f32.f16 %va, %a_b16;
cvt.f32.f16 %vb, %b_b16;
mul.f32 %vr, %va, %vb;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// PTX source for `add_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It computes `out[i] = f16(f32(a[i]) + f32(b[i]))`, rounding to nearest even.
/// Parameters: `(a_ptr, b_ptr, out_ptr, n: u32)`.
const ADD_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry add_f16_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .b16 %a_b16, %b_b16, %out_h;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
ld.global.b16 %b_b16, [%b];
cvt.f32.f16 %va, %a_b16;
cvt.f32.f16 %vb, %b_b16;
add.f32 %vr, %va, %vb;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// PTX source for `sub_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It computes `out[i] = f16(f32(a[i]) - f32(b[i]))`, rounding to nearest even.
/// Parameters: `(a_ptr, b_ptr, out_ptr, n: u32)`.
const SUB_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry sub_f16_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .b16 %a_b16, %b_b16, %out_h;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
ld.global.b16 %b_b16, [%b];
cvt.f32.f16 %va, %a_b16;
cvt.f32.f16 %vb, %b_b16;
sub.f32 %vr, %va, %vb;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// PTX source for `div_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It computes `out[i] = f16(f32(a[i]) / f32(b[i]))` using the fast `div.approx.f32`, then
/// rounds to f16 (nearest even).
/// NOTE(review): `div.rn.f32` would give an IEEE correctly rounded quotient if exact division
/// semantics are ever required.
/// Parameters: `(a_ptr, b_ptr, out_ptr, n: u32)`.
const DIV_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry div_f16_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .b16 %a_b16, %b_b16, %out_h;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
ld.global.b16 %b_b16, [%b];
cvt.f32.f16 %va, %a_b16;
cvt.f32.f16 %vb, %b_b16;
div.approx.f32 %vr, %va, %vb;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// PTX source for `silu_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It computes `out[i] = f16(x * sigmoid(x))` in f32, where `sigmoid(x) = 1 / (1 + e)`,
/// `e = 2^(-x * log2(e))` (`ex2.approx`) and the reciprocal uses `div.approx.f32`.
/// `0f3FB8AA3B` is log2(e) and `0f3F800000` is 1.0.
/// Parameters: `(a_ptr, out_ptr, n: u32)`.
const SILU_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry silu_f16_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .b16 %a_b16, %out_h;
.reg .f32 %va, %neg_a, %log2e, %x, %e, %one, %denom, %sig, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
cvt.f32.f16 %va, %a_b16;
neg.f32 %neg_a, %va;
mov.f32 %log2e, 0f3FB8AA3B;
mul.f32 %x, %neg_a, %log2e;
ex2.approx.f32 %e, %x;
mov.f32 %one, 0f3F800000;
add.f32 %denom, %one, %e;
div.approx.f32 %sig, %one, %denom;
mul.f32 %vr, %va, %sig;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// PTX source for `relu_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It computes `out[i] = f16(max.f32(x, 0.0))`.
/// NOTE(review): per the PTX ISA, `max.f32` returns the non-NaN operand, so a NaN input likely
/// produces 0 rather than NaN. Confirm this is intended.
/// Parameters: `(a_ptr, out_ptr, n: u32)`.
const RELU_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry relu_f16_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .b16 %a_b16, %out_h;
.reg .f32 %va, %zero, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
cvt.f32.f16 %va, %a_b16;
mov.f32 %zero, 0f00000000;
max.f32 %vr, %va, %zero;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// PTX source for `gelu_f16_kernel` (PTX ISA 7.0, `sm_53`): erf-based GELU
/// `0.5 * x * (1 + erf(x / sqrt(2)))`.
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// erf is computed for `|arg|` as `1 - poly(t) * exp(-arg^2)` with `t = 1 / (1 + p * |arg|)`.
/// It uses a degree-5 polynomial in `t`, and its constants match Abramowitz & Stegun 7.1.26.
/// The sign is then restored for negative arguments. The reciprocal and exponential use
/// `rcp.approx` and `ex2.approx`. All math is f32; the result is rounded to f16.
/// Parameters: `(in_ptr, out_ptr, n: u32)`.
const GELU_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry gelu_f16_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .b16 %x_b16, %out_h;
.reg .f32 %x, %inv_sqrt2, %arg, %erf_v, %one, %half_c, %sum, %y;
.reg .f32 %ax, %t, %p_const, %one2, %neg_xx, %log2e, %exp_v, %poly, %c_a1, %c_a2, %c_a3, %c_a4, %c_a5;
.reg .pred %p, %signp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.b16 %x_b16, [%in];
cvt.f32.f16 %x, %x_b16;
// arg = x / sqrt(2); 1/sqrt(2) = 0.70710678118f -> 0x3F3504F3
mov.f32 %inv_sqrt2, 0f3F3504F3;
mul.f32 %arg, %x, %inv_sqrt2;
// ax = |arg|
abs.f32 %ax, %arg;
// t = 1 / (1 + 0.3275911 * ax); p = 0.3275911 -> 0x3EA7BA05
mov.f32 %p_const, 0f3EA7BA05;
mov.f32 %one2, 0f3F800000;
fma.rn.f32 %t, %p_const, %ax, %one2;
rcp.approx.f32 %t, %t;
// exp_v = exp(-ax*ax) computed as 2^((-ax*ax) * log2(e)).
mul.f32 %neg_xx, %ax, %ax;
neg.f32 %neg_xx, %neg_xx;
mov.f32 %log2e, 0f3FB8AA3B;
mul.f32 %neg_xx, %neg_xx, %log2e;
ex2.approx.f32 %exp_v, %neg_xx;
// Horner-eval poly = ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t
mov.f32 %c_a5, 0f3F87DC22;
mov.f32 %c_a4, 0fBFBA00E3;
mov.f32 %c_a3, 0f3FB5F0E3;
mov.f32 %c_a2, 0fBE91A98E;
mov.f32 %c_a1, 0f3E827906;
mul.f32 %poly, %c_a5, %t;
add.f32 %poly, %poly, %c_a4;
mul.f32 %poly, %poly, %t;
add.f32 %poly, %poly, %c_a3;
mul.f32 %poly, %poly, %t;
add.f32 %poly, %poly, %c_a2;
mul.f32 %poly, %poly, %t;
add.f32 %poly, %poly, %c_a1;
mul.f32 %poly, %poly, %t;
// erf(|arg|) = 1 - poly * exp_v
mul.f32 %poly, %poly, %exp_v;
mov.f32 %one, 0f3F800000;
sub.f32 %erf_v, %one, %poly;
// Restore sign.
setp.lt.f32 %signp, %arg, 0f00000000;
@%signp neg.f32 %erf_v, %erf_v;
// y = 0.5 * x * (1 + erf_v)
add.f32 %sum, %one, %erf_v;
mov.f32 %half_c, 0f3F000000;
mul.f32 %y, %half_c, %x;
mul.f32 %y, %y, %sum;
cvt.rn.f16.f32 %out_h, %y;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// Builds a 1-D launch configuration with enough `BLOCK_SIZE`-thread blocks to cover `n`
/// elements, one thread per element.
///
/// At least one block is always launched, even for `n == 0`. The elementwise kernels index
/// with 32-bit thread ids, so callers must keep `n <= u32::MAX`. A block count that does not
/// fit in a `u32` is clamped, and the driver then rejects the oversized grid.
fn launch_1d(n: usize) -> LaunchConfig {
    // Round up in usize. Rounding in u32 with `saturating_add` loses the final partial block
    // once `n` is within `BLOCK_SIZE - 1` of `u32::MAX`, leaving those elements unwritten.
    let blocks = n.div_ceil(BLOCK_SIZE as usize);
    let grid = u32::try_from(blocks).unwrap_or(u32::MAX);
    LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// Runs a two-input elementwise f16 kernel and returns a freshly allocated output buffer.
///
/// `ptx` must define an entry point named `kernel_name` with the parameters
/// `(a_ptr: u64, b_ptr: u64, out_ptr: u64, n: u32)`, as the binary `*_F16_PTX` kernels in
/// this module do. All buffers hold raw f16 bit patterns in `u16` slots.
///
/// # Errors
///
/// - `GpuError::LengthMismatch` if `a` and `b` have different lengths.
/// - `GpuError::ShapeMismatch` (`expected: [u32::MAX]`, `got: [len]`) if the length does not
///   fit the kernel's u32 element count.
/// - `GpuError::PtxCompileFailed` if the module cannot be compiled or loaded.
/// - Driver errors from allocation or launch.
fn launch_binary(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    if a.len() != b.len() {
        return Err(GpuError::LengthMismatch {
            a: a.len(),
            b: b.len(),
        });
    }
    let n = a.len();
    if n == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(0)?);
    }
    // The kernel receives the element count as a u32 and derives indices from 32-bit thread
    // ids. A plain `n as u32` would make it silently process only `n mod 2^32` elements.
    let n_u32 = u32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: kernel_name,
        expected: vec![u32::MAX as usize],
        got: vec![n],
    })?;
    let ctx = device.context();
    let stream = device.stream();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out = stream.alloc_zeros::<u16>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the pushed arguments (two input pointers, one output pointer, a u32 count)
    // match the entry-point signature required above. `a`, `b` and `out` each hold `n`
    // elements, and the binary kernels in this module return early for indices >= n.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a)
            .arg(b)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Runs a one-input elementwise f16 kernel and returns a freshly allocated output buffer.
///
/// `ptx` must define an entry point named `kernel_name` with the parameters
/// `(a_ptr: u64, out_ptr: u64, n: u32)`, as the unary `*_F16_PTX` kernels in this module do.
/// Buffers hold raw f16 bit patterns in `u16` slots.
///
/// # Errors
///
/// - `GpuError::ShapeMismatch` (`expected: [u32::MAX]`, `got: [len]`) if the length does not
///   fit the kernel's u32 element count.
/// - `GpuError::PtxCompileFailed` if the module cannot be compiled or loaded.
/// - Driver errors from allocation or launch.
fn launch_unary(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let n = a.len();
    if n == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(0)?);
    }
    // The kernel receives the element count as a u32 and derives indices from 32-bit thread
    // ids. A plain `n as u32` would make it silently process only `n mod 2^32` elements.
    let n_u32 = u32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: kernel_name,
        expected: vec![u32::MAX as usize],
        got: vec![n],
    })?;
    let ctx = device.context();
    let stream = device.stream();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out = stream.alloc_zeros::<u16>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the pushed arguments (input pointer, output pointer, u32 count) match the
    // entry-point signature required above. `a` and `out` each hold `n` elements, and the
    // unary kernels in this module return early for indices >= n.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Elementwise product `a * b` of two equally long f16 buffers (raw bits in `u16`).
///
/// Each pair is widened to f32, multiplied, and rounded back to f16 (nearest even).
/// Returns `GpuError::LengthMismatch` when `a.len() != b.len()`.
pub fn gpu_mul_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "mul_f16_kernel";
    launch_binary(a, b, device, MUL_F16_PTX, ENTRY)
}
/// Elementwise sum `a + b` of two equally long f16 buffers (raw bits in `u16`).
///
/// Each pair is widened to f32, added, and rounded back to f16 (nearest even).
/// Returns `GpuError::LengthMismatch` when `a.len() != b.len()`.
pub fn gpu_add_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "add_f16_kernel";
    launch_binary(a, b, device, ADD_F16_PTX, ENTRY)
}
/// Elementwise difference `a - b` of two equally long f16 buffers (raw bits in `u16`).
///
/// Each pair is widened to f32, subtracted, and rounded back to f16 (nearest even).
/// Returns `GpuError::LengthMismatch` when `a.len() != b.len()`.
pub fn gpu_sub_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "sub_f16_kernel";
    launch_binary(a, b, device, SUB_F16_PTX, ENTRY)
}
/// Elementwise quotient `a / b` of two equally long f16 buffers (raw bits in `u16`).
///
/// The division is the fast approximate f32 divide (`div.approx.f32`), rounded back to f16.
/// Returns `GpuError::LengthMismatch` when `a.len() != b.len()`.
pub fn gpu_div_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "div_f16_kernel";
    launch_binary(a, b, device, DIV_F16_PTX, ENTRY)
}
/// Elementwise SiLU, `x * sigmoid(x)`, over an f16 buffer (raw bits in `u16`).
///
/// The computation is done in f32 and rounded to f16; an empty input yields an empty output.
pub fn gpu_silu_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "silu_f16_kernel";
    launch_unary(a, device, SILU_F16_PTX, ENTRY)
}
/// Elementwise ReLU, `max(x, 0)`, over an f16 buffer (raw bits in `u16`).
///
/// An empty input yields an empty output.
pub fn gpu_relu_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "relu_f16_kernel";
    launch_unary(a, device, RELU_F16_PTX, ENTRY)
}
/// Elementwise erf-based GELU, `0.5 * x * (1 + erf(x / sqrt(2)))`, over an f16 buffer.
///
/// erf comes from a polynomial approximation evaluated in f32; results are rounded to f16.
pub fn gpu_gelu_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "gelu_f16_kernel";
    launch_unary(a, device, GELU_F16_PTX, ENTRY)
}
/// PTX source for `exp_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It computes `out[i] = f16(2^(x * log2(e)))` using `ex2.approx.f32`.
/// Parameters: `(a_ptr, out_ptr, n: u32)`.
const EXP_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry exp_f16_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .b16 %a_b16, %out_h;
.reg .f32 %va, %x, %log2e, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
cvt.f32.f16 %va, %a_b16;
// exp(x) = 2^(x * log2(e)); log2(e) = 0x3FB8AA3B as f32 bits.
mov.f32 %log2e, 0f3FB8AA3B;
mul.f32 %x, %va, %log2e;
ex2.approx.f32 %vr, %x;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// PTX source for `log_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It computes the natural log as `out[i] = f16(lg2.approx(x) * ln(2))`.
/// Parameters: `(a_ptr, out_ptr, n: u32)`.
const LOG_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry log_f16_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .b16 %a_b16, %out_h;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
cvt.f32.f16 %va, %a_b16;
// ln(x) = lg2(x) * ln(2); ln(2) = 0x3F317218.
lg2.approx.f32 %vr, %va;
mul.f32 %vr, %vr, 0f3F317218;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// PTX source for `tanh_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, exits if `i >= n`, and
/// otherwise stores `out[i] = f16(tanh(f32(a[i])))`.
///
/// tanh is evaluated as `copysign((1 - e) / (1 + e), x)` with `e = exp(-2|x|)`. Because `e`
/// stays in `[0, 1]` it cannot overflow, and the `div.approx.f32` divisor stays in `[1, 2]`.
///
/// The direct form `(e^(2x) - 1) / (e^(2x) + 1)` breaks for positive `x` above about 43.7,
/// which is well inside the f16 range. There `div.approx.f32` returns 0 once its divisor
/// exceeds 2^126, and past about 44.4 `e^(2x)` overflows to infinity, giving `inf / inf = NaN`.
/// `copysign` also keeps `tanh(-0.0) == -0.0`.
/// Parameters: `(a_ptr, out_ptr, n: u32)`.
const TANH_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry tanh_f16_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .b16 %a_b16, %out_h;
.reg .f32 %va, %ax, %arg, %e, %num, %den, %vr, %one;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
cvt.f32.f16 %va, %a_b16;
// e = exp(-2|x|) = 2^(|x| * -2*log2(e)); -2*log2(e) = 0xC038AA3B. e is in [0, 1].
abs.f32 %ax, %va;
mul.f32 %arg, %ax, 0fC038AA3B;
ex2.approx.f32 %e, %arg;
// tanh(|x|) = (1 - e) / (1 + e); the divisor is in [1, 2].
mov.f32 %one, 0f3F800000;
sub.f32 %num, %one, %e;
add.f32 %den, %one, %e;
div.approx.f32 %vr, %num, %den;
// tanh is odd: give the result the sign of x (also maps -0 to -0).
copysign.f32 %vr, %va, %vr;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// PTX source for `sqrt_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It computes `out[i] = f16(sqrt.approx.f32(x))`.
/// Parameters: `(a_ptr, out_ptr, n: u32)`.
const SQRT_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry sqrt_f16_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .b16 %a_b16, %out_h;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
cvt.f32.f16 %va, %a_b16;
sqrt.approx.f32 %vr, %va;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// PTX source for `sigmoid_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It computes `out[i] = f16(1 / (1 + 2^(-x * log2(e))))` using `ex2.approx.f32` and
/// `div.approx.f32`.
/// Parameters: `(a_ptr, out_ptr, n: u32)`.
const SIGMOID_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry sigmoid_f16_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .b16 %a_b16, %out_h;
.reg .f32 %va, %neg_a, %log2e, %arg, %e, %one, %denom, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
cvt.f32.f16 %va, %a_b16;
neg.f32 %neg_a, %va;
mov.f32 %log2e, 0f3FB8AA3B;
mul.f32 %arg, %neg_a, %log2e;
ex2.approx.f32 %e, %arg;
mov.f32 %one, 0f3F800000;
add.f32 %denom, %one, %e;
div.approx.f32 %vr, %one, %denom;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// Elementwise natural exponential `e^x` over an f16 buffer (raw bits in `u16`).
///
/// Evaluated in f32 and rounded to f16; an empty input yields an empty output.
pub fn gpu_exp_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "exp_f16_kernel";
    launch_unary(a, device, EXP_F16_PTX, ENTRY)
}
/// Elementwise natural logarithm `ln(x)` over an f16 buffer (raw bits in `u16`).
///
/// Evaluated in f32 and rounded to f16; an empty input yields an empty output.
pub fn gpu_log_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "log_f16_kernel";
    launch_unary(a, device, LOG_F16_PTX, ENTRY)
}
/// Elementwise hyperbolic tangent over an f16 buffer (raw bits in `u16`).
///
/// Evaluated in f32 and rounded to f16; an empty input yields an empty output.
pub fn gpu_tanh_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "tanh_f16_kernel";
    launch_unary(a, device, TANH_F16_PTX, ENTRY)
}
/// Elementwise logistic sigmoid `1 / (1 + e^-x)` over an f16 buffer (raw bits in `u16`).
///
/// Evaluated in f32 and rounded to f16; an empty input yields an empty output.
pub fn gpu_sigmoid_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "sigmoid_f16_kernel";
    launch_unary(a, device, SIGMOID_F16_PTX, ENTRY)
}
/// Elementwise square root over an f16 buffer (raw bits in `u16`).
///
/// Evaluated with the approximate f32 square root and rounded to f16.
pub fn gpu_sqrt_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const ENTRY: &str = "sqrt_f16_kernel";
    launch_unary(a, device, SQRT_F16_PTX, ENTRY)
}
/// PTX source for `scale_f16_kernel` (PTX ISA 7.0, `sm_53`).
///
/// Each thread handles one element, `i = ctaid.x * ntid.x + tid.x`, and exits if `i >= n`.
/// It computes `out[i] = f16(f32(a[i]) * scale)`, where `scale` is an f32 kernel parameter.
/// Parameters: `(a_ptr, out_ptr, scale: f32, n: u32)`.
const SCALE_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry scale_f16_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .f32 scale,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .b16 %a_b16, %out_h;
.reg .f32 %va, %scale_r, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.f32 %scale_r, [scale];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.b16 %a_b16, [%a];
cvt.f32.f16 %va, %a_b16;
mul.f32 %vr, %va, %scale_r;
cvt.rn.f16.f32 %out_h, %vr;
st.global.b16 [%out], %out_h;
DONE:
ret;
}
";
/// Multiplies every f16 element of `input` (raw bits in `u16`) by `scale`.
///
/// The product is formed in f32 and rounded back to f16. An empty input yields an empty
/// output.
///
/// # Errors
///
/// - `GpuError::ShapeMismatch` (`expected: [u32::MAX]`, `got: [len]`) if `input` has more
///   elements than the kernel's u32 count can describe.
/// - `GpuError::PtxCompileFailed` if the kernel cannot be compiled or loaded.
/// - Driver errors from allocation or launch.
pub fn gpu_scale_f16(
    input: &cudarc::driver::CudaSlice<u16>,
    scale: f32,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let n = input.len();
    if n == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(0)?);
    }
    // The kernel takes the element count as a u32. `n as u32` would silently scale only the
    // first `n mod 2^32` elements and leave the rest of `out` zeroed.
    let n_u32 = u32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "scale_f16",
        expected: vec![u32::MAX as usize],
        got: vec![n],
    })?;
    let ctx = device.context();
    let stream = device.stream();
    let f = get_or_compile(
        ctx,
        SCALE_F16_PTX,
        "scale_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "scale_f16_kernel",
        source: e,
    })?;
    let mut out = stream.alloc_zeros::<u16>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the arguments (input ptr, output ptr, f32 scale, u32 count) match
    // `scale_f16_kernel`'s parameter list. `input` and `out` both hold `n` elements, and
    // threads with index >= n exit before touching memory.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(&mut out)
            .arg(&scale)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Elementwise negation `-x` of an f16 buffer, implemented as a scale by `-1.0`.
pub fn gpu_neg_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    const MINUS_ONE: f32 = -1.0;
    gpu_scale_f16(a, MINUS_ONE, device)
}
/// PTX source for `sum_axis_f16_kernel` (PTX ISA 7.0, `sm_53`). It reduces the middle axis of a
/// row-major `[outer, axis_size, inner]` f16 tensor.
///
/// Each thread handles one output element `r = oi * inner + ii`; threads with
/// `r >= outer * inner` exit. A thread serially adds the `axis_size` inputs at
/// `(oi * axis_size + k) * inner + ii` into an f32 accumulator. When `do_mean != 0` it divides
/// by `axis_size` (`div.approx.f32`), then rounds to f16 and stores `out[r]`.
///
/// All index arithmetic is 32-bit, and the kernel receives no input length. The caller must
/// therefore guarantee that the whole input extent is in bounds.
/// Parameters: `(a_ptr, out_ptr, outer, axis_size, inner, do_mean)`; all counts are u32.
const SUM_AXIS_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry sum_axis_f16_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 outer,
.param .u32 axis_size,
.param .u32 inner,
.param .u32 do_mean
) {
.reg .u32 %r_tid, %bid, %bdim, %outer_r, %axis_r, %inner_r;
.reg .u32 %total_out, %oi, %ii, %k, %a_idx, %do_mean_r;
.reg .u64 %a, %out, %off;
.reg .b16 %a_b16, %out_h;
.reg .f32 %acc, %va, %scale;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %outer_r, [outer];
ld.param.u32 %axis_r, [axis_size];
ld.param.u32 %inner_r, [inner];
ld.param.u32 %do_mean_r, [do_mean];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
mul.lo.u32 %total_out, %outer_r, %inner_r;
setp.ge.u32 %p, %r_tid, %total_out;
@%p bra DONE;
div.u32 %oi, %r_tid, %inner_r;
rem.u32 %ii, %r_tid, %inner_r;
mov.f32 %acc, 0f00000000;
mov.u32 %k, 0;
LOOP:
setp.ge.u32 %p, %k, %axis_r;
@%p bra LOOP_END;
// a_idx = oi * axis_size * inner + k * inner + ii.
mul.lo.u32 %a_idx, %oi, %axis_r;
add.u32 %a_idx, %a_idx, %k;
mul.lo.u32 %a_idx, %a_idx, %inner_r;
add.u32 %a_idx, %a_idx, %ii;
cvt.u64.u32 %off, %a_idx;
shl.b64 %off, %off, 1;
add.u64 %off, %a, %off;
ld.global.b16 %a_b16, [%off];
cvt.f32.f16 %va, %a_b16;
add.f32 %acc, %acc, %va;
add.u32 %k, %k, 1;
bra LOOP;
LOOP_END:
// If do_mean, divide by axis_size.
setp.eq.u32 %p, %do_mean_r, 0;
@%p bra STORE;
cvt.rn.f32.u32 %scale, %axis_r;
div.approx.f32 %acc, %acc, %scale;
STORE:
cvt.rn.f16.f32 %out_h, %acc;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %off, %out, %off;
st.global.b16 [%off], %out_h;
DONE:
ret;
}
";
/// Validates an `[outer, axis_size, inner]` reduction over `a` and returns the output length
/// `outer * inner`.
///
/// `sum_axis_f16_kernel` reads `a` at `(oi * axis_size + k) * inner + ii` using u32 arithmetic
/// and has no bounds check of its own. The full input extent must therefore lie inside `a`
/// and fit in a u32.
///
/// # Errors
///
/// `GpuError::ShapeMismatch` (tagged with `op`) in any of these cases:
/// - `outer * inner` overflows;
/// - `a` holds fewer than `outer * axis_size * inner` elements;
/// - that element count exceeds `u32::MAX`.
fn sum_axis_output_len(
    op: &'static str,
    a: &cudarc::driver::CudaSlice<u16>,
    outer: usize,
    axis_size: usize,
    inner: usize,
) -> GpuResult<usize> {
    let total = outer
        .checked_mul(inner)
        .ok_or_else(|| GpuError::ShapeMismatch {
            op,
            expected: vec![outer, inner],
            got: vec![usize::MAX],
        })?;
    let in_bounds = total
        .checked_mul(axis_size)
        .is_some_and(|needed| needed <= a.len() && u32::try_from(needed).is_ok());
    if !in_bounds {
        return Err(GpuError::ShapeMismatch {
            op,
            expected: vec![outer, axis_size, inner],
            got: vec![a.len()],
        });
    }
    Ok(total)
}

/// Launches `sum_axis_f16_kernel` with one thread per output element. `do_mean` selects
/// division by `axis_size`.
///
/// The shape must already be validated by `sum_axis_output_len`, with `outer * inner > 0`
/// and `axis_size > 0`.
fn launch_sum_axis_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    do_mean: bool,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let ctx = device.context();
    let stream = device.stream();
    let f = get_or_compile(
        ctx,
        SUM_AXIS_F16_PTX,
        "sum_axis_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "sum_axis_f16_kernel",
        source: e,
    })?;
    let total = outer * inner;
    let mut out = stream.alloc_zeros::<u16>(total)?;
    let cfg = launch_1d(total);
    // Lossless casts: validation guarantees outer * axis_size * inner <= u32::MAX with every
    // factor >= 1, so each factor also fits in a u32.
    let outer_u32 = outer as u32;
    let axis_u32 = axis_size as u32;
    let inner_u32 = inner as u32;
    let do_mean_u32 = u32::from(do_mean);
    // SAFETY: the argument order and types match the kernel's parameter list
    // (a_ptr, out_ptr, outer, axis_size, inner, do_mean). `out` holds `outer * inner`
    // elements, and `a` holds at least `outer * axis_size * inner`, which bounds every index
    // the kernel reads.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a)
            .arg(&mut out)
            .arg(&outer_u32)
            .arg(&axis_u32)
            .arg(&inner_u32)
            .arg(&do_mean_u32)
            .launch(cfg)?;
    }
    Ok(out)
}

/// Sums a row-major `[outer, axis_size, inner]` f16 tensor over its middle axis.
///
/// Produces `outer * inner` f16 values, each accumulated in f32 and rounded once. An empty
/// axis (`axis_size == 0`) sums to `+0.0`.
///
/// # Errors
///
/// - `GpuError::ShapeMismatch` if the shape overflows, exceeds `a.len()`, or needs more than
///   `u32::MAX` input elements.
/// - `GpuError::PtxCompileFailed` or a driver error from the launch.
pub fn gpu_sum_axis_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let total = sum_axis_output_len("sum_axis_f16", a, outer, axis_size, inner)?;
    if total == 0 || axis_size == 0 {
        // Nothing to read. An empty sum is 0, which freshly zeroed f16 storage already holds.
        return Ok(device.stream().alloc_zeros::<u16>(total)?);
    }
    launch_sum_axis_f16(a, outer, axis_size, inner, false, device)
}

/// Averages a row-major `[outer, axis_size, inner]` f16 tensor over its middle axis.
///
/// Produces `outer * inner` f16 values. Each sum is accumulated and divided in f32 before the
/// single rounding to f16. An empty axis (`axis_size == 0`) yields the f16 quiet NaN
/// `0x7E00` for every output.
///
/// # Errors
///
/// - `GpuError::ShapeMismatch` if the shape overflows, exceeds `a.len()`, or needs more than
///   `u32::MAX` input elements.
/// - `GpuError::PtxCompileFailed` or a driver error from the copy or launch.
pub fn gpu_mean_axis_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let total = sum_axis_output_len("mean_axis_f16", a, outer, axis_size, inner)?;
    let stream = device.stream();
    if total == 0 {
        return Ok(stream.alloc_zeros::<u16>(0)?);
    }
    if axis_size == 0 {
        // The mean of an empty axis is 0 / 0, so every output is NaN.
        let mut out = stream.alloc_zeros::<u16>(total)?;
        let nan_bits = vec![0x7E00_u16; total];
        stream.memcpy_htod(&nan_bits, &mut out)?;
        return Ok(out);
    }
    launch_sum_axis_f16(a, outer, axis_size, inner, true, device)
}
/// Sums every element of `a` into a one-element f16 buffer. An empty input yields `+0.0`.
///
/// This is the `[1, len, 1]` case of `gpu_sum_axis_f16`, accumulated in f32 before rounding.
/// NOTE(review): with `outer = inner = 1` the axis kernel has a single active thread that walks
/// all elements serially. A parallel block reduction would scale far better for large inputs.
pub fn gpu_sum_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    match a.len() {
        0 => Ok(device.stream().alloc_zeros::<u16>(1)?),
        len => gpu_sum_axis_f16(a, 1, len, 1, device),
    }
}
pub fn gpu_mean_f16(
a: &cudarc::driver::CudaSlice<u16>,
device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
let n = a.len();
if n == 0 {
let stream = device.stream();
let mut out = stream.alloc_zeros::<u16>(1)?;
let nan_bits = vec![0x7E00_u16];
stream.memcpy_htod(&nan_bits, &mut out)?;
return Ok(out);
}
let sum = gpu_sum_f16(a, device)?;
let inv_n = 1.0f32 / (n as f32);
gpu_scale_f16(&sum, inv_n, device)
}
/// PTX source for `softmax_f16_kernel` (PTX ISA 7.0, `sm_53`). It computes a max-subtracted,
/// row-wise softmax of a row-major `[rows, cols]` f16 matrix.
///
/// Block `ctaid.x` owns one row; blocks with `ctaid.x >= rows` exit. Its threads walk the row
/// with stride `ntid.x`. Three passes, all in f32:
/// 1. a per-thread max, combined by a shared-memory tree reduction into the row max;
/// 2. a per-thread sum of `exp(x - max)` (via `ex2.approx`), reduced the same way;
/// 3. each element is recomputed and written as `exp(x - max) * rcp.approx(sum)`, rounded to
///    f16.
///
/// The scratch array `softmax_f16_sdata` holds 256 f32 values, and the tree halves `ntid.x` at
/// each step. The block size must therefore be a power of two no larger than 256.
/// Parameters: `(in_ptr, out_ptr, rows: u32, cols: u32)`.
const SOFTMAX_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.shared .align 4 .f32 softmax_f16_sdata[256];
.visible .entry softmax_f16_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 rows,
.param .u32 cols
) {
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %otid;
.reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;
.reg .b16 %x_b16, %out_h;
.reg .f32 %x_f, %tmax, %other, %row_max, %sum, %inv_sum, %e, %scale, %log2e, %y_f;
.reg .pred %p, %lp, %rp, %gp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
mov.u64 %sbase, softmax_f16_sdata;
mov.f32 %log2e, 0f3FB8AA3B;
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %bid, %rows_reg;
@%p bra DONE;
cvt.u64.u32 %row_off, %bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 1;
// Pass 1: thread-local max
mov.f32 %tmax, 0fFF800000; // -Inf
mov.u32 %j, %r_tid;
MX:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra MXD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.b16 %x_b16, [%off];
cvt.f32.f16 %x_f, %x_b16;
setp.gt.f32 %gp, %x_f, %tmax;
@%gp mov.f32 %tmax, %x_f;
add.u32 %j, %j, %bdim;
bra MX;
MXD:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %tmax;
bar.sync 0;
mov.u32 %half, %bdim;
MR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra MRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra MRS;
add.u32 %otid, %r_tid, %half;
cvt.u64.u32 %off, %otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %tmax, [%saddr];
setp.gt.f32 %gp, %other, %tmax;
@%gp mov.f32 %tmax, %other;
st.shared.f32 [%saddr], %tmax;
MRS:
bar.sync 0;
bra MR;
MRD:
ld.shared.f32 %row_max, [%sbase];
bar.sync 0;
// Pass 2: thread-local sum of exp(v - row_max)
mov.f32 %sum, 0f00000000;
mov.u32 %j, %r_tid;
SE:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra SED;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.b16 %x_b16, [%off];
cvt.f32.f16 %x_f, %x_b16;
sub.f32 %x_f, %x_f, %row_max;
mul.f32 %scale, %x_f, %log2e;
ex2.approx.f32 %e, %scale;
add.f32 %sum, %sum, %e;
add.u32 %j, %j, %bdim;
bra SE;
SED:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sum;
bar.sync 0;
mov.u32 %half, %bdim;
SER:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra SERD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra SERS;
add.u32 %otid, %r_tid, %half;
cvt.u64.u32 %off, %otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sum, [%saddr];
add.f32 %sum, %sum, %other;
st.shared.f32 [%saddr], %sum;
SERS:
bar.sync 0;
bra SER;
SERD:
ld.shared.f32 %sum, [%sbase];
rcp.approx.f32 %inv_sum, %sum;
bar.sync 0;
// Pass 3: write
mov.u32 %j, %r_tid;
WR:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra WRD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.b16 %x_b16, [%off];
cvt.f32.f16 %x_f, %x_b16;
sub.f32 %x_f, %x_f, %row_max;
mul.f32 %scale, %x_f, %log2e;
ex2.approx.f32 %e, %scale;
mul.f32 %y_f, %e, %inv_sum;
cvt.rn.f16.f32 %out_h, %y_f;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %out, %off;
add.u64 %off, %off, %row_off;
st.global.b16 [%off], %out_h;
add.u32 %j, %j, %bdim;
bra WR;
WRD:
DONE:
ret;
}
";
/// Row-wise softmax of a row-major `[rows, cols]` f16 matrix stored as raw bits in `u16`.
///
/// Launches one `BLOCK_SIZE`-thread block per row. Each block reduces the row max and the sum
/// of `exp(x - max)` in f32, then writes `exp(x - max) / sum` rounded to f16. Elements of
/// `input` beyond `rows * cols` are ignored. A zero-sized shape yields an empty output.
///
/// # Errors
///
/// - `GpuError::ShapeMismatch` if `rows * cols` overflows, if `input` holds fewer than
///   `rows * cols` elements, or if `rows`/`cols` do not fit the kernel's u32 parameters.
/// - `GpuError::PtxCompileFailed` or a driver error from allocation or launch.
pub fn gpu_softmax_f16(
    input: &cudarc::driver::CudaSlice<u16>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let shape_error = || GpuError::ShapeMismatch {
        op: "softmax_f16",
        expected: vec![rows, cols],
        got: vec![input.len()],
    };
    // Checked product: a wrapped `rows * cols` could slip past the length check below and
    // let the kernel read beyond `input`.
    let total = rows.checked_mul(cols).ok_or_else(shape_error)?;
    if total == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(0)?);
    }
    if input.len() < total {
        return Err(shape_error());
    }
    // `rows` becomes grid.x and both dimensions are u32 kernel parameters. Truncating either
    // would skip rows or misplace row offsets, so reject values that do not fit.
    let (Ok(rows_u32), Ok(cols_u32)) = (u32::try_from(rows), u32::try_from(cols)) else {
        return Err(shape_error());
    };
    let ctx = device.context();
    let stream = device.stream();
    let f = get_or_compile(
        ctx,
        SOFTMAX_F16_PTX,
        "softmax_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "softmax_f16_kernel",
        source: e,
    })?;
    let mut out = stream.alloc_zeros::<u16>(total)?;
    let cfg = LaunchConfig {
        grid_dim: (rows_u32, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the arguments (in ptr, out ptr, u32 rows, u32 cols) match
    // `softmax_f16_kernel`'s parameters. `input` holds at least rows * cols elements and `out`
    // exactly that many. The kernel touches only row `ctaid.x < rows` and columns `< cols`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(&mut out)
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// PTX source for `layernorm_f16_kernel` (PTX ISA 7.0, `sm_53`). It layer-normalizes each row
/// of a row-major `[rows, cols]` f16 matrix, using per-column f16 `gamma` and `beta`.
///
/// Block `ctaid.x` owns one row; blocks with `ctaid.x >= rows` exit. Threads walk the row with
/// stride `ntid.x`. The phases, all in f32:
/// 1. `mean = sum(x) / cols`;
/// 2. `var = sum((x - mean)^2) / cols`, i.e. the biased two-pass variance, then
///    `inv_std = rcp(sqrt(var + eps))` using the `.approx` instructions;
/// 3. `out[j] = f16(gamma[j] * (x - mean) * inv_std + beta[j])`.
///
/// Each sum is combined with a shared-memory tree reduction over `layernorm_f16_sdata[256]`.
/// The block size must therefore be a power of two no larger than 256.
/// Parameters: `(in_ptr, gamma_ptr, beta_ptr, out_ptr, rows: u32, cols: u32, eps: f32)`.
const LAYERNORM_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.shared .align 4 .f32 layernorm_f16_sdata[256];
.visible .entry layernorm_f16_kernel(
.param .u64 in_ptr,
.param .u64 gamma_ptr,
.param .u64 beta_ptr,
.param .u64 out_ptr,
.param .u32 rows,
.param .u32 cols,
.param .f32 eps
) {
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %otid;
.reg .u64 %in, %gam, %bet, %out, %row_off, %off, %sbase, %saddr;
.reg .b16 %x_b16, %g_b16, %b_b16, %out_h;
.reg .f32 %x_f, %g_f, %b_f, %sum, %mean, %diff, %var, %eps_r, %inv_std, %normed, %r_f, %other, %n_f;
.reg .pred %p, %lp, %rp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %gam, [gamma_ptr];
ld.param.u64 %bet, [beta_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
ld.param.f32 %eps_r, [eps];
mov.u64 %sbase, layernorm_f16_sdata;
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %bid, %rows_reg;
@%p bra DONE;
cvt.u64.u32 %row_off, %bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 1;
cvt.rn.f32.u32 %n_f, %cols_reg;
// Phase 1: sum(x) -> mean
mov.f32 %sum, 0f00000000;
mov.u32 %j, %r_tid;
SM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra SMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.b16 %x_b16, [%off];
cvt.f32.f16 %x_f, %x_b16;
add.f32 %sum, %sum, %x_f;
add.u32 %j, %j, %bdim;
bra SM;
SMD:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sum;
bar.sync 0;
mov.u32 %half, %bdim;
MR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra MRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra MRS;
add.u32 %otid, %r_tid, %half;
cvt.u64.u32 %off, %otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sum, [%saddr];
add.f32 %sum, %sum, %other;
st.shared.f32 [%saddr], %sum;
MRS:
bar.sync 0;
bra MR;
MRD:
ld.shared.f32 %sum, [%sbase];
div.approx.f32 %mean, %sum, %n_f;
bar.sync 0;
// Phase 2: sum((x - mean)^2) -> var
mov.f32 %var, 0f00000000;
mov.u32 %j, %r_tid;
SV:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra SVD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.b16 %x_b16, [%off];
cvt.f32.f16 %x_f, %x_b16;
sub.f32 %diff, %x_f, %mean;
fma.rn.f32 %var, %diff, %diff, %var;
add.u32 %j, %j, %bdim;
bra SV;
SVD:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %var;
bar.sync 0;
mov.u32 %half, %bdim;
VR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra VRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra VRS;
add.u32 %otid, %r_tid, %half;
cvt.u64.u32 %off, %otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %var, [%saddr];
add.f32 %var, %var, %other;
st.shared.f32 [%saddr], %var;
VRS:
bar.sync 0;
bra VR;
VRD:
ld.shared.f32 %var, [%sbase];
div.approx.f32 %var, %var, %n_f;
add.f32 %var, %var, %eps_r;
sqrt.approx.f32 %inv_std, %var;
rcp.approx.f32 %inv_std, %inv_std;
bar.sync 0;
// Phase 3: out = ((x - mean) * inv_std) * gamma + beta, rounded to f16
mov.u32 %j, %r_tid;
NM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra NMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.b16 %x_b16, [%off];
cvt.f32.f16 %x_f, %x_b16;
sub.f32 %normed, %x_f, %mean;
mul.f32 %normed, %normed, %inv_std;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %gam, %off;
ld.global.b16 %g_b16, [%off];
cvt.f32.f16 %g_f, %g_b16;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %bet, %off;
ld.global.b16 %b_b16, [%off];
cvt.f32.f16 %b_f, %b_b16;
fma.rn.f32 %r_f, %g_f, %normed, %b_f;
cvt.rn.f16.f32 %out_h, %r_f;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %out, %off;
add.u64 %off, %off, %row_off;
st.global.b16 [%off], %out_h;
add.u32 %j, %j, %bdim;
bra NM;
NMD:
DONE:
ret;
}
";
/// PTX for `rmsnorm_f16_kernel`: RMS normalization of f16 rows (raw `u16` bits).
///
/// Launch layout: one block per row (`%ctaid.x` selects the row and blocks
/// with an index `>= rows` exit immediately); the threads of a block stride
/// over that row's columns in steps of `%ntid.x`. Parameters, in order:
/// `in_ptr`, `w_ptr`, `out_ptr`, `rows: u32`, `cols: u32`, `eps: f32`.
///
/// - Phase 1: each thread accumulates its share of `sum(x^2)` in f32 (via
///   `fma`), stores it at its thread index in the 256-entry shared array
///   `rmsnorm_f16_sdata`, and the block tree-reduces it by repeatedly halving
///   `%ntid.x`. The block size must therefore be a power of two and at most
///   256; the host launcher uses `BLOCK_SIZE`.
/// - Then `inv_rms = 1 / sqrt(sum / cols + eps)`, using the `.approx` forms
///   of `div`, `sqrt` and `rcp`.
/// - Phase 2: `out[j] = x[j] * inv_rms * w[j]`, computed in f32 and rounded
///   to f16 with `cvt.rn`; the weight is indexed by column only.
const RMSNORM_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.shared .align 4 .f32 rmsnorm_f16_sdata[256];
.visible .entry rmsnorm_f16_kernel(
.param .u64 in_ptr,
.param .u64 w_ptr,
.param .u64 out_ptr,
.param .u32 rows,
.param .u32 cols,
.param .f32 eps
) {
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %otid;
.reg .u64 %in, %w, %out, %row_off, %off, %sbase, %saddr;
.reg .b16 %x_b16, %w_b16, %out_h;
.reg .f32 %x_f, %w_f, %sq_sum, %eps_r, %inv_rms, %mean_sq, %r_f, %other, %n_f;
.reg .pred %p, %lp, %rp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %w, [w_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
ld.param.f32 %eps_r, [eps];
mov.u64 %sbase, rmsnorm_f16_sdata;
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %bid, %rows_reg;
@%p bra DONE;
cvt.u64.u32 %row_off, %bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 1;
cvt.rn.f32.u32 %n_f, %cols_reg;
// Phase 1: sum(x^2) in f32
mov.f32 %sq_sum, 0f00000000;
mov.u32 %j, %r_tid;
SS:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra SSD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.b16 %x_b16, [%off];
cvt.f32.f16 %x_f, %x_b16;
fma.rn.f32 %sq_sum, %x_f, %x_f, %sq_sum;
add.u32 %j, %j, %bdim;
bra SS;
SSD:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sq_sum;
bar.sync 0;
mov.u32 %half, %bdim;
SR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra SRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra SRS;
add.u32 %otid, %r_tid, %half;
cvt.u64.u32 %off, %otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sq_sum, [%saddr];
add.f32 %sq_sum, %sq_sum, %other;
st.shared.f32 [%saddr], %sq_sum;
SRS:
bar.sync 0;
bra SR;
SRD:
ld.shared.f32 %sq_sum, [%sbase];
div.approx.f32 %mean_sq, %sq_sum, %n_f;
add.f32 %mean_sq, %mean_sq, %eps_r;
sqrt.approx.f32 %inv_rms, %mean_sq;
rcp.approx.f32 %inv_rms, %inv_rms;
bar.sync 0;
// Phase 2: out = x * inv_rms * weight, rounded to f16
mov.u32 %j, %r_tid;
NM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra NMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.b16 %x_b16, [%off];
cvt.f32.f16 %x_f, %x_b16;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %w, %off;
ld.global.b16 %w_b16, [%off];
cvt.f32.f16 %w_f, %w_b16;
mul.f32 %r_f, %x_f, %inv_rms;
mul.f32 %r_f, %r_f, %w_f;
cvt.rn.f16.f32 %out_h, %r_f;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 1;
add.u64 %off, %out, %off;
add.u64 %off, %off, %row_off;
st.global.b16 [%off], %out_h;
add.u32 %j, %j, %bdim;
bra NM;
NMD:
DONE:
ret;
}
";
/// Applies layer normalization to each row of an f16 matrix on the GPU.
///
/// f16 values travel as raw `u16` bit patterns. `input` must hold at least
/// `rows * cols` elements and `gamma` / `beta` at least `cols` elements each.
/// `layernorm_f16_kernel` is launched with one `BLOCK_SIZE`-thread block per
/// row; per its PTX, each output element is
/// `((x - mean) * inv_std) * gamma[j] + beta[j]` with
/// `inv_std = 1 / sqrt(var + eps)`, computed in f32 and rounded to f16.
///
/// Returns a freshly allocated buffer of `rows * cols` elements. When `rows`
/// or `cols` is zero, it returns an empty buffer without launching anything.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `rows` or `cols` exceeds `u32::MAX` (the
///   kernel receives both as `u32`), if `input` is shorter than `rows * cols`,
///   or if `gamma` or `beta` is shorter than `cols`.
/// - [`GpuError::PtxCompileFailed`] if the kernel fails to compile.
/// - Allocation and launch errors are propagated.
pub fn gpu_layernorm_f16(
    input: &cudarc::driver::CudaSlice<u16>,
    gamma: &cudarc::driver::CudaSlice<u16>,
    beta: &cudarc::driver::CudaSlice<u16>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    if rows == 0 || cols == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(rows * cols)?);
    }
    // The kernel takes `rows` / `cols` as u32. A plain `as` cast would silently
    // truncate them and normalize only a prefix of the data. Checking first also
    // keeps `rows * cols` below from overflowing on 64-bit targets.
    if rows > u32::MAX as usize || cols > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "layernorm_f16",
            expected: vec![u32::MAX as usize, u32::MAX as usize],
            got: vec![rows, cols],
        });
    }
    if input.len() < rows * cols {
        return Err(GpuError::ShapeMismatch {
            op: "layernorm_f16",
            expected: vec![rows, cols],
            got: vec![input.len()],
        });
    }
    if gamma.len() < cols || beta.len() < cols {
        return Err(GpuError::ShapeMismatch {
            op: "layernorm_f16",
            expected: vec![cols],
            got: vec![gamma.len().min(beta.len())],
        });
    }
    // Both casts are lossless after the range check above.
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    let ctx = device.context();
    let stream = device.stream();
    let f = get_or_compile(
        ctx,
        LAYERNORM_F16_PTX,
        "layernorm_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "layernorm_f16_kernel",
        source: e,
    })?;
    let mut out = stream.alloc_zeros::<u16>(rows * cols)?;
    let cfg = LaunchConfig {
        grid_dim: (rows_u32, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        // The kernel's reduction scratch is presumably declared statically in
        // its PTX (as in the RMSNorm kernel), so no dynamic shared memory is needed.
        shared_mem_bytes: 0,
    };
    // SAFETY: the arguments are pushed in the order/types the launcher has
    // always used for `layernorm_f16_kernel` (in, gamma, beta, out, rows, cols,
    // eps). `input`, `gamma` and `beta` were length-checked above, and `out`
    // holds exactly `rows * cols` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(gamma)
            .arg(beta)
            .arg(&mut out)
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Applies RMS normalization to each row of an f16 matrix on the GPU.
///
/// f16 values travel as raw `u16` bit patterns. `input` must hold at least
/// `rows * cols` elements (row-major, one kernel block per row) and `weight`
/// at least `cols` elements. Per `RMSNORM_F16_PTX`, each output element is
/// `x[j] * weight[j] / sqrt(mean(x^2) + eps)`, computed in f32 and rounded
/// to f16.
///
/// Returns a freshly allocated buffer of `rows * cols` elements. When `rows`
/// or `cols` is zero, it returns an empty buffer without launching anything.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `rows` or `cols` exceeds `u32::MAX` (the
///   kernel receives both as `u32`), if `input` is shorter than `rows * cols`,
///   or if `weight` is shorter than `cols`.
/// - [`GpuError::PtxCompileFailed`] if the kernel fails to compile.
/// - Allocation and launch errors are propagated.
pub fn gpu_rmsnorm_f16(
    input: &cudarc::driver::CudaSlice<u16>,
    weight: &cudarc::driver::CudaSlice<u16>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    if rows == 0 || cols == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(rows * cols)?);
    }
    // The kernel takes `rows` / `cols` as `.param .u32`. A plain `as` cast would
    // silently truncate them and normalize only a prefix of the data. Checking
    // first also keeps `rows * cols` below from overflowing on 64-bit targets.
    if rows > u32::MAX as usize || cols > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "rmsnorm_f16",
            expected: vec![u32::MAX as usize, u32::MAX as usize],
            got: vec![rows, cols],
        });
    }
    if input.len() < rows * cols {
        return Err(GpuError::ShapeMismatch {
            op: "rmsnorm_f16",
            expected: vec![rows, cols],
            got: vec![input.len()],
        });
    }
    if weight.len() < cols {
        return Err(GpuError::ShapeMismatch {
            op: "rmsnorm_f16",
            expected: vec![cols],
            got: vec![weight.len()],
        });
    }
    // Both casts are lossless after the range check above.
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    let ctx = device.context();
    let stream = device.stream();
    let f = get_or_compile(
        ctx,
        RMSNORM_F16_PTX,
        "rmsnorm_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "rmsnorm_f16_kernel",
        source: e,
    })?;
    let mut out = stream.alloc_zeros::<u16>(rows * cols)?;
    let cfg = LaunchConfig {
        grid_dim: (rows_u32, 1, 1),
        // Must stay a power of two <= 256: the kernel tree-reduces into a
        // statically declared 256-entry shared array.
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the arguments match the kernel's `.param` list in order and type
    // (in_ptr, w_ptr, out_ptr, rows: u32, cols: u32, eps: f32). `input` and
    // `weight` were length-checked above, and `out` holds `rows * cols` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(weight)
            .arg(&mut out)
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(out)
}
/// PTX for `broadcast_add_f16_kernel`: element-wise `a + b` over f16 data
/// (raw `u16` bits), with broadcasting expressed through per-axis strides.
///
/// There is one thread per output element. The global index is
/// `ctaid.x * ntid.x + tid.x`, and threads with an index `>= n` exit. Each
/// thread peels its flat index into coordinates, from the last dimension to
/// the first, using the `u32` array at `out_shape_ptr`. It then dots those
/// coordinates with the `u32` stride arrays at `a_strides_ptr` and
/// `b_strides_ptr` (all three are `ndim` long) to locate its source elements.
/// A stride of 0 re-reads the same element along that axis.
///
/// Operands are widened to f32, added, and rounded back to f16 (`cvt.rn`).
/// All index arithmetic is 32-bit and source indices are not bounds-checked,
/// so the caller must supply strides consistent with the operand buffers.
///
/// Parameters, in order: `a_ptr`, `b_ptr`, `out_ptr`, `a_strides_ptr`,
/// `b_strides_ptr`, `out_shape_ptr`, `n: u32`, `ndim: u32`.
const BROADCAST_ADD_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry broadcast_add_f16_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .b16 %a_b16, %b_b16, %out_h;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 1;
add.u64 %off_a, %a, %off_a;
ld.global.b16 %a_b16, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 1;
add.u64 %off_b, %b, %off_b;
ld.global.b16 %b_b16, [%off_b];
cvt.f32.f16 %va, %a_b16;
cvt.f32.f16 %vb, %b_b16;
add.f32 %vr, %va, %vb;
cvt.rn.f16.f32 %out_h, %vr;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 1;
add.u64 %off_out, %out, %off_out;
st.global.b16 [%off_out], %out_h;
DONE:
ret;
}
";
/// PTX for `broadcast_sub_f16_kernel`: element-wise `a - b` over f16 data
/// (raw `u16` bits), with broadcasting expressed through per-axis strides.
///
/// There is one thread per output element. The global index is
/// `ctaid.x * ntid.x + tid.x`, and threads with an index `>= n` exit. Each
/// thread peels its flat index into coordinates, from the last dimension to
/// the first, using the `u32` array at `out_shape_ptr`. It then dots those
/// coordinates with the `u32` stride arrays at `a_strides_ptr` and
/// `b_strides_ptr` (all three are `ndim` long) to locate its source elements.
/// A stride of 0 re-reads the same element along that axis.
///
/// Operands are widened to f32, subtracted, and rounded back to f16
/// (`cvt.rn`). All index arithmetic is 32-bit and source indices are not
/// bounds-checked, so the caller must supply strides consistent with the
/// operand buffers.
///
/// Parameters, in order: `a_ptr`, `b_ptr`, `out_ptr`, `a_strides_ptr`,
/// `b_strides_ptr`, `out_shape_ptr`, `n: u32`, `ndim: u32`.
const BROADCAST_SUB_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry broadcast_sub_f16_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .b16 %a_b16, %b_b16, %out_h;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 1;
add.u64 %off_a, %a, %off_a;
ld.global.b16 %a_b16, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 1;
add.u64 %off_b, %b, %off_b;
ld.global.b16 %b_b16, [%off_b];
cvt.f32.f16 %va, %a_b16;
cvt.f32.f16 %vb, %b_b16;
sub.f32 %vr, %va, %vb;
cvt.rn.f16.f32 %out_h, %vr;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 1;
add.u64 %off_out, %out, %off_out;
st.global.b16 [%off_out], %out_h;
DONE:
ret;
}
";
/// PTX for `broadcast_mul_f16_kernel`: element-wise `a * b` over f16 data
/// (raw `u16` bits), with broadcasting expressed through per-axis strides.
///
/// There is one thread per output element. The global index is
/// `ctaid.x * ntid.x + tid.x`, and threads with an index `>= n` exit. Each
/// thread peels its flat index into coordinates, from the last dimension to
/// the first, using the `u32` array at `out_shape_ptr`. It then dots those
/// coordinates with the `u32` stride arrays at `a_strides_ptr` and
/// `b_strides_ptr` (all three are `ndim` long) to locate its source elements.
/// A stride of 0 re-reads the same element along that axis.
///
/// Operands are widened to f32, multiplied, and rounded back to f16
/// (`cvt.rn`). All index arithmetic is 32-bit and source indices are not
/// bounds-checked, so the caller must supply strides consistent with the
/// operand buffers.
///
/// Parameters, in order: `a_ptr`, `b_ptr`, `out_ptr`, `a_strides_ptr`,
/// `b_strides_ptr`, `out_shape_ptr`, `n: u32`, `ndim: u32`.
const BROADCAST_MUL_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry broadcast_mul_f16_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .b16 %a_b16, %b_b16, %out_h;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 1;
add.u64 %off_a, %a, %off_a;
ld.global.b16 %a_b16, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 1;
add.u64 %off_b, %b, %off_b;
ld.global.b16 %b_b16, [%off_b];
cvt.f32.f16 %va, %a_b16;
cvt.f32.f16 %vb, %b_b16;
mul.f32 %vr, %va, %vb;
cvt.rn.f16.f32 %out_h, %vr;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 1;
add.u64 %off_out, %out, %off_out;
st.global.b16 [%off_out], %out_h;
DONE:
ret;
}
";
/// PTX for `broadcast_div_f16_kernel`: element-wise `a / b` over f16 data
/// (raw `u16` bits), with broadcasting expressed through per-axis strides.
///
/// There is one thread per output element. The global index is
/// `ctaid.x * ntid.x + tid.x`, and threads with an index `>= n` exit. Each
/// thread peels its flat index into coordinates, from the last dimension to
/// the first, using the `u32` array at `out_shape_ptr`. It then dots those
/// coordinates with the `u32` stride arrays at `a_strides_ptr` and
/// `b_strides_ptr` (all three are `ndim` long) to locate its source elements.
/// A stride of 0 re-reads the same element along that axis.
///
/// Operands are widened to f32, divided with the approximate
/// `div.approx.f32` instruction (not IEEE-rounded), and rounded back to f16
/// (`cvt.rn`). All index arithmetic is 32-bit and source indices are not
/// bounds-checked, so the caller must supply strides consistent with the
/// operand buffers.
///
/// Parameters, in order: `a_ptr`, `b_ptr`, `out_ptr`, `a_strides_ptr`,
/// `b_strides_ptr`, `out_shape_ptr`, `n: u32`, `ndim: u32`.
const BROADCAST_DIV_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry broadcast_div_f16_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .b16 %a_b16, %b_b16, %out_h;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 1;
add.u64 %off_a, %a, %off_a;
ld.global.b16 %a_b16, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 1;
add.u64 %off_b, %b, %off_b;
ld.global.b16 %b_b16, [%off_b];
cvt.f32.f16 %va, %a_b16;
cvt.f32.f16 %vb, %b_b16;
div.approx.f32 %vr, %va, %vb;
cvt.rn.f16.f32 %out_h, %vr;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 1;
add.u64 %off_out, %out, %off_out;
st.global.b16 [%off_out], %out_h;
DONE:
ret;
}
";
/// Per-axis element strides that map an `out_shape` coordinate onto a
/// contiguous operand of shape `shape`.
///
/// The two shapes are aligned at their trailing dimensions. Stride 0 goes to
/// the leading axes of `out_shape` that `shape` lacks, and to the axes where
/// `shape` has extent 1, so the kernel re-reads the same operand element along
/// them. Every other axis gets the operand's own row-major stride, narrowed
/// to `u32`.
///
/// Expects `shape.len() <= out_shape.len()`; the subtraction below overflows
/// otherwise.
fn broadcast_strides_f16(shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
    let lead = out_shape.len() - shape.len();
    // Row-major strides of `shape` itself: the innermost axis has stride 1 and
    // each outer axis is the product of the extents to its right.
    let mut contiguous = vec![1_usize; shape.len()];
    for axis in (1..shape.len()).rev() {
        contiguous[axis - 1] = contiguous[axis] * shape[axis];
    }
    (0..out_shape.len())
        .map(|axis| match axis.checked_sub(lead) {
            Some(src) if shape[src] != 1 => contiguous[src] as u32,
            _ => 0,
        })
        .collect()
}
/// Validates one operand of a broadcast binary op before it reaches the
/// kernel, which indexes operand buffers without any bounds checks.
///
/// `shape` must not have more dimensions than `out_shape`. Aligned to the
/// trailing dimensions, each of its extents must equal the output extent or
/// be 1. `len` (the operand buffer's element count) must cover
/// `shape.iter().product()` elements. `op` is reported in the error.
fn check_broadcast_operand_f16(
    len: usize,
    shape: &[usize],
    out_shape: &[usize],
    op: &'static str,
) -> GpuResult<()> {
    // `&&` short-circuits, so the slice start cannot underflow.
    let compatible = shape.len() <= out_shape.len()
        && shape
            .iter()
            .zip(&out_shape[out_shape.len() - shape.len()..])
            .all(|(&dim, &out_dim)| dim == out_dim || dim == 1);
    if !compatible {
        return Err(GpuError::ShapeMismatch {
            op,
            expected: out_shape.to_vec(),
            got: shape.to_vec(),
        });
    }
    let numel: usize = shape.iter().product();
    if len < numel {
        return Err(GpuError::ShapeMismatch {
            op,
            expected: vec![numel],
            got: vec![len],
        });
    }
    Ok(())
}
/// Shared launcher for the `broadcast_*_f16_kernel` family.
///
/// Computes u32 per-axis strides for `a` and `b` relative to `out_shape` (see
/// `broadcast_strides_f16`). It uploads them together with `out_shape` to the
/// device and launches one thread per output element over f16 data stored as
/// `u16` bit patterns. An output with zero elements yields an empty buffer
/// without compiling or launching anything.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] (with `op` set to `kernel_name`) if:
///   - the output has more than `u32::MAX` elements (the kernels index with u32);
///   - an operand shape cannot be broadcast to `out_shape`;
///   - an operand buffer is shorter than its shape's element count.
/// - [`GpuError::PtxCompileFailed`] if `ptx` fails to compile.
/// - Upload, allocation and launch errors are propagated.
#[allow(clippy::too_many_arguments)]
fn launch_broadcast_binary_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let out_numel: usize = out_shape.iter().product();
    let stream = device.stream();
    if out_numel == 0 {
        return Ok(stream.alloc_zeros::<u16>(0)?);
    }
    // All kernel index math and the stride/shape buffers are u32. Every output
    // extent (and every operand stride) is <= out_numel, so this one check
    // makes all the u32 narrowing below lossless.
    if out_numel > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: kernel_name,
            expected: vec![u32::MAX as usize],
            got: vec![out_numel],
        });
    }
    // The kernel trusts these shapes blindly, so reject anything that would
    // index past either operand buffer.
    check_broadcast_operand_f16(a.len(), a_shape, out_shape, kernel_name)?;
    check_broadcast_operand_f16(b.len(), b_shape, out_shape, kernel_name)?;
    let a_str = broadcast_strides_f16(a_shape, out_shape);
    let b_str = broadcast_strides_f16(b_shape, out_shape);
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let a_str_buf = crate::transfer::cpu_to_gpu(&a_str, device)?;
    let b_str_buf = crate::transfer::cpu_to_gpu(&b_str, device)?;
    let shape_buf = crate::transfer::cpu_to_gpu(&shape_u32, device)?;
    let mut out = stream.alloc_zeros::<u16>(out_numel)?;
    let cfg = launch_1d(out_numel);
    let n_u32 = out_numel as u32;
    let ndim_u32 = out_shape.len() as u32;
    // SAFETY: the arguments match the kernels' `.param` list in order and type
    // (a, b, out, a_strides, b_strides, out_shape, n: u32, ndim: u32). Both
    // operands were validated above, so every index the kernel forms stays below
    // the operand's element count. `out` holds `n` elements, and the three
    // metadata buffers hold `ndim` u32 values each and outlive this launch call.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a)
            .arg(b)
            .arg(&mut out)
            .arg(a_str_buf.inner())
            .arg(b_str_buf.inner())
            .arg(shape_buf.inner())
            .arg(&n_u32)
            .arg(&ndim_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Element-wise `a + b` on f16 tensors (raw `u16` bit patterns) on the GPU.
///
/// `a_shape` and `b_shape` are broadcast to `out_shape` by aligning trailing
/// dimensions and repeating size-1 or missing axes. Each element is widened to
/// f32, added, and rounded back to f16. Errors are those of
/// `launch_broadcast_binary_f16`.
pub fn gpu_broadcast_add_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let (ptx, kernel) = (BROADCAST_ADD_F16_PTX, "broadcast_add_f16_kernel");
    launch_broadcast_binary_f16(a, b, a_shape, b_shape, out_shape, device, ptx, kernel)
}
/// Element-wise `a - b` on f16 tensors (raw `u16` bit patterns) on the GPU.
///
/// `a_shape` and `b_shape` are broadcast to `out_shape` by aligning trailing
/// dimensions and repeating size-1 or missing axes. Each element is widened to
/// f32, subtracted, and rounded back to f16. Errors are those of
/// `launch_broadcast_binary_f16`.
pub fn gpu_broadcast_sub_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let (ptx, kernel) = (BROADCAST_SUB_F16_PTX, "broadcast_sub_f16_kernel");
    launch_broadcast_binary_f16(a, b, a_shape, b_shape, out_shape, device, ptx, kernel)
}
/// Element-wise `a * b` on f16 tensors (raw `u16` bit patterns) on the GPU.
///
/// `a_shape` and `b_shape` are broadcast to `out_shape` by aligning trailing
/// dimensions and repeating size-1 or missing axes. Each element is widened to
/// f32, multiplied, and rounded back to f16. Errors are those of
/// `launch_broadcast_binary_f16`.
pub fn gpu_broadcast_mul_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let (ptx, kernel) = (BROADCAST_MUL_F16_PTX, "broadcast_mul_f16_kernel");
    launch_broadcast_binary_f16(a, b, a_shape, b_shape, out_shape, device, ptx, kernel)
}
/// Element-wise `a / b` on f16 tensors (raw `u16` bit patterns) on the GPU.
///
/// `a_shape` and `b_shape` are broadcast to `out_shape` by aligning trailing
/// dimensions and repeating size-1 or missing axes. Each element is widened to
/// f32, divided with the kernel's approximate `div.approx.f32`, and rounded
/// back to f16. Errors are those of `launch_broadcast_binary_f16`.
pub fn gpu_broadcast_div_f16(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    let (ptx, kernel) = (BROADCAST_DIV_F16_PTX, "broadcast_div_f16_kernel");
    launch_broadcast_binary_f16(a, b, a_shape, b_shape, out_shape, device, ptx, kernel)
}