#[cfg(feature = "cuda")]
use cudarc::driver::LaunchConfig;
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
#[cfg(feature = "cuda")]
use crate::transfer::{alloc_zeros, cpu_to_gpu, gpu_to_cpu};
#[cfg(feature = "cuda")]
/// Hand-written PTX for `add_kernel`: elementwise f32 addition,
/// `out[i] = a[i] + b[i]` for `i` in `0..n`.
///
/// Layout: one thread per element. The global index is computed as
/// `ctaid.x * ntid.x + tid.x`; threads whose index is `>= n` branch to
/// `DONE` without touching memory, so over-provisioned grids are safe.
/// The byte offset is `index << 2` because elements are 4-byte f32.
///
/// NOTE(review): `.reg .u32 %tid` reuses the name of the PTX special
/// register `%tid`; some ptxas versions may reject the redeclaration. If
/// compilation fails, `get_or_compile` errors and callers take the CPU
/// fallback path — confirm on real hardware.
pub(crate) const ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry add_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tid, %tid.x;
mad.lo.u32 %tid, %bid, %bdim, %tid;
setp.ge.u32 %p, %tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
add.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// Hand-written PTX for `sub_kernel`: elementwise f32 subtraction,
/// `out[i] = a[i] - b[i]` for `i` in `0..n`.
///
/// One thread per element; indices `>= n` branch to `DONE` (bounds check
/// inside the kernel), and the byte offset is `index << 2` for f32.
pub(crate) const SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sub_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tid, %tid.x;
mad.lo.u32 %tid, %bid, %bdim, %tid;
setp.ge.u32 %p, %tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
sub.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// Hand-written PTX for `mul_kernel`: elementwise f32 multiplication,
/// `out[i] = a[i] * b[i]` for `i` in `0..n`.
///
/// One thread per element; indices `>= n` branch to `DONE` (bounds check
/// inside the kernel), and the byte offset is `index << 2` for f32.
pub(crate) const MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mul_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tid, %tid.x;
mad.lo.u32 %tid, %bid, %bdim, %tid;
setp.ge.u32 %p, %tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
mul.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// Hand-written PTX for `neg_kernel`: elementwise f32 negation,
/// `out[i] = -a[i]` for `i` in `0..n`.
///
/// One thread per element; indices `>= n` branch to `DONE` (bounds check
/// inside the kernel), and the byte offset is `index << 2` for f32.
pub(crate) const NEG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry neg_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tid, %tid.x;
mad.lo.u32 %tid, %bid, %bdim, %tid;
setp.ge.u32 %p, %tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
neg.f32 %vr, %va;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// Hand-written PTX for `relu_kernel`: elementwise f32 ReLU,
/// `out[i] = max(a[i], 0.0)` for `i` in `0..n`.
///
/// `0f00000000` is the PTX hex-float spelling of 0.0f32. One thread per
/// element; indices `>= n` branch to `DONE` (bounds check inside the
/// kernel), and the byte offset is `index << 2` for f32.
pub(crate) const RELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry relu_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %zero;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tid, %tid.x;
mad.lo.u32 %tid, %bid, %bdim, %tid;
setp.ge.u32 %p, %tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
mov.f32 %zero, 0f00000000;
max.f32 %vr, %va, %zero;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
/// Compute a 1-D launch configuration covering `n` elements with 256-thread
/// blocks.
///
/// # Errors
/// Returns `GpuError::ShapeMismatch` when `n` exceeds `u32::MAX`, since the
/// kernels take the element count as a `u32` parameter.
fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "kernel_launch",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    const BLOCK: u32 = 256;
    // Ceiling division done in u64. The previous u32 formula
    // `(n as u32).saturating_add(BLOCK - 1) / BLOCK` saturated for any `n`
    // within 255 of u32::MAX, producing a grid one block too small and
    // leaving the tail elements unprocessed. The u64 result is at most
    // ceil(u32::MAX / 256) = 16_777_216, which always fits back into u32.
    let grid = ((n as u64 + (BLOCK as u64 - 1)) / BLOCK as u64) as u32;
    Ok(LaunchConfig {
        // grid.max(1): CUDA rejects zero-sized grids, so n == 0 still gets
        // one (fully masked-out) block.
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    })
}
#[cfg(feature = "cuda")]
/// Check that both operands live on `device` and have equal lengths.
///
/// # Errors
/// `GpuError::DeviceMismatch` if either buffer's ordinal differs from the
/// device's (checked for `a` first, then `b`); `GpuError::LengthMismatch`
/// if the element counts differ.
fn validate_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    let dev_ord = device.ordinal();
    // Check `a` before `b` so the error reports the first offender,
    // matching the original check order.
    for buf in [a, b] {
        if buf.device_ordinal() != dev_ord {
            return Err(GpuError::DeviceMismatch {
                expected: buf.device_ordinal(),
                got: dev_ord,
            });
        }
    }
    if a.len() == b.len() {
        Ok(())
    } else {
        Err(GpuError::LengthMismatch {
            a: a.len(),
            b: b.len(),
        })
    }
}
#[cfg(feature = "cuda")]
/// Check that the operand lives on `device`.
///
/// # Errors
/// `GpuError::DeviceMismatch` when the buffer's ordinal differs from the
/// device's.
fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    let buf_ord = a.device_ordinal();
    let dev_ord = device.ordinal();
    if buf_ord == dev_ord {
        Ok(())
    } else {
        Err(GpuError::DeviceMismatch {
            expected: buf_ord,
            got: dev_ord,
        })
    }
}
#[cfg(feature = "cuda")]
/// Try to run a binary elementwise kernel (`out[i] = op(a[i], b[i])`) on the
/// GPU.
///
/// Returns `Ok(None)` when the PTX module fails to compile or load — callers
/// treat that as "use the CPU fallback" — `Ok(Some(out))` on a successful
/// launch, and `Err` for allocation or launch-size failures.
///
/// Callers are expected to have run `validate_binary` first, so `a` and `b`
/// have equal lengths on the right device.
fn try_launch_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;
    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();
    // A compile/load failure is deliberately non-fatal: report "no GPU
    // result" so the caller can fall back to the CPU implementation.
    let f = match crate::module_cache::get_or_compile(ctx, ptx_src, kernel_name, device.ordinal() as u32) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };
    // Validate the launch size BEFORE allocating the output buffer, so an
    // oversized `n` fails fast without first claiming device memory.
    // (Previously the allocation happened first.)
    let cfg = launch_cfg(n)?;
    let mut out = alloc_zeros::<f32>(n, device)?;
    let n_u32 = n as u32;
    // SAFETY: the argument list matches the kernel's PTX signature
    // (a_ptr, b_ptr, out_ptr, n); `a`, `b`, and `out` each hold `n` f32
    // elements, and the kernel bounds-checks its thread index against `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(Some(out))
}
#[cfg(feature = "cuda")]
/// Try to run a unary elementwise kernel (`out[i] = op(a[i])`) on the GPU.
///
/// Returns `Ok(None)` when the PTX module fails to compile or load — callers
/// treat that as "use the CPU fallback" — `Ok(Some(out))` on a successful
/// launch, and `Err` for allocation or launch-size failures.
fn try_launch_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;
    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();
    // A compile/load failure is deliberately non-fatal: report "no GPU
    // result" so the caller can fall back to the CPU implementation.
    let f = match crate::module_cache::get_or_compile(ctx, ptx_src, kernel_name, device.ordinal() as u32) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };
    // Validate the launch size BEFORE allocating the output buffer, so an
    // oversized `n` fails fast without first claiming device memory.
    // (Previously the allocation happened first.)
    let cfg = launch_cfg(n)?;
    let mut out = alloc_zeros::<f32>(n, device)?;
    let n_u32 = n as u32;
    // SAFETY: the argument list matches the kernel's PTX signature
    // (a_ptr, out_ptr, n); `a` and `out` each hold `n` f32 elements, and
    // the kernel bounds-checks its thread index against `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(Some(out))
}
#[cfg(feature = "cuda")]
/// CPU fallback for binary elementwise ops: download both operands, apply
/// `op` pairwise on the host, and upload the result back to the device.
fn cpu_fallback_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let mut combined = Vec::with_capacity(lhs.len());
    for (x, y) in lhs.into_iter().zip(rhs) {
        combined.push(op(x, y));
    }
    cpu_to_gpu(&combined, device)
}
#[cfg(feature = "cuda")]
/// CPU fallback for unary elementwise ops: download the operand, apply `op`
/// on the host, and upload the result back to the device.
fn cpu_fallback_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(a, device)?;
    let mapped: Vec<f32> = host.into_iter().map(op).collect();
    cpu_to_gpu(&mapped, device)
}
#[cfg(feature = "cuda")]
/// Elementwise addition of two device buffers, returning a new buffer.
///
/// Tries the GPU kernel first; if the PTX cannot be compiled/loaded, the
/// operation is computed on the host instead.
///
/// # Errors
/// Device or length mismatches from validation, plus any allocation,
/// transfer, or launch failure.
pub fn gpu_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;
    match try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x + y),
    }
}
#[cfg(feature = "cuda")]
/// Elementwise subtraction (`a - b`) of two device buffers, returning a new
/// buffer.
///
/// Tries the GPU kernel first; if the PTX cannot be compiled/loaded, the
/// operation is computed on the host instead.
///
/// # Errors
/// Device or length mismatches from validation, plus any allocation,
/// transfer, or launch failure.
pub fn gpu_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;
    match try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x - y),
    }
}
#[cfg(feature = "cuda")]
/// Elementwise multiplication of two device buffers, returning a new buffer.
///
/// Tries the GPU kernel first; if the PTX cannot be compiled/loaded, the
/// operation is computed on the host instead.
///
/// # Errors
/// Device or length mismatches from validation, plus any allocation,
/// transfer, or launch failure.
pub fn gpu_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;
    match try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x * y),
    }
}
#[cfg(feature = "cuda")]
/// Elementwise negation of a device buffer, returning a new buffer.
///
/// Tries the GPU kernel first; if the PTX cannot be compiled/loaded, the
/// operation is computed on the host instead.
///
/// # Errors
/// Device mismatch from validation, plus any allocation, transfer, or
/// launch failure.
pub fn gpu_neg(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, NEG_PTX, "neg_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| -x),
    }
}
#[cfg(feature = "cuda")]
/// Elementwise ReLU (`max(x, 0.0)`) of a device buffer, returning a new
/// buffer.
///
/// Tries the GPU kernel first; if the PTX cannot be compiled/loaded, the
/// operation is computed on the host instead.
///
/// # Errors
/// Device mismatch from validation, plus any allocation, transfer, or
/// launch failure.
pub fn gpu_relu(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.max(0.0)),
    }
}
#[cfg(not(feature = "cuda"))]
/// Stub compiled when the `cuda` feature is disabled; always returns
/// `GpuError::NoCudaFeature`.
pub fn gpu_add(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
/// Stub compiled when the `cuda` feature is disabled; always returns
/// `GpuError::NoCudaFeature`.
pub fn gpu_sub(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
/// Stub compiled when the `cuda` feature is disabled; always returns
/// `GpuError::NoCudaFeature`.
pub fn gpu_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
/// Stub compiled when the `cuda` feature is disabled; always returns
/// `GpuError::NoCudaFeature`.
pub fn gpu_neg(
    _a: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
/// Stub compiled when the `cuda` feature is disabled; always returns
/// `GpuError::NoCudaFeature`.
pub fn gpu_relu(
    _a: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(test)]
#[cfg(feature = "cuda")]
// These tests require a working CUDA device at ordinal 0. They exercise the
// public gpu_* entry points end-to-end; note that the same assertions hold
// whether the kernel path or the CPU fallback produced the result.
mod tests {
    use super::*;

    /// Open device 0 and upload `data`, returning both handles.
    /// Panics (via `expect`) if no CUDA device is available.
    fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
        (dev, buf)
    }

    /// Download `buf` and compare elementwise against `expected` with an
    /// absolute tolerance of 1e-6.
    fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
        let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
        assert_eq!(host.len(), expected.len(), "length mismatch");
        for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    #[test]
    fn add_basic() {
        let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let b_data = vec![10.0f32, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add");
        assert_buf_eq(&out, &dev, &expected);
    }

    // Zero-length buffers must round-trip without launching a zero-sized grid.
    #[test]
    fn add_empty() {
        let (dev, a) = setup(&[]);
        let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add empty");
        assert_eq!(out.len(), 0);
    }

    // 100k elements spans many 256-thread blocks, exercising the grid math.
    #[test]
    fn add_large() {
        let n = 100_000;
        let a_data: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b_data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add large");
        assert_buf_eq(&out, &dev, &expected);
    }

    // Length validation must fire before any kernel work and report both sizes.
    #[test]
    fn add_length_mismatch() {
        let (dev, a) = setup(&[1.0, 2.0, 3.0]);
        let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
        let err = gpu_add(&a, &b, &dev).unwrap_err();
        match err {
            GpuError::LengthMismatch { a: 3, b: 2 } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn sub_basic() {
        let a_data = vec![10.0f32, 20.0, 30.0, 40.0];
        let b_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x - y).collect();
        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&out, &dev, &expected);
    }

    // Subtraction must produce correctly signed negative results.
    #[test]
    fn sub_negative_result() {
        let a_data = vec![1.0f32, 2.0];
        let b_data = vec![5.0f32, 10.0];
        let expected: Vec<f32> = vec![-4.0, -8.0];
        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn mul_basic() {
        let a_data = vec![2.0f32, 3.0, 4.0, 5.0];
        let b_data = vec![10.0f32, 10.0, 10.0, 10.0];
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x * y).collect();
        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn mul_by_zero() {
        let a_data = vec![1.0f32, 2.0, 3.0];
        let b_data = vec![0.0f32, 0.0, 0.0];
        let expected = vec![0.0f32, 0.0, 0.0];
        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn neg_basic() {
        let a_data = vec![1.0f32, -2.0, 3.0, 0.0, -5.5];
        let expected: Vec<f32> = a_data.iter().map(|x| -x).collect();
        let (dev, a) = setup(&a_data);
        let out = gpu_neg(&a, &dev).expect("gpu_neg");
        assert_buf_eq(&out, &dev, &expected);
    }

    // -(-x) == x: negation applied twice should round-trip exactly.
    #[test]
    fn neg_double_negation() {
        let a_data = vec![1.0f32, -2.0, 3.0];
        let (dev, a) = setup(&a_data);
        let neg1 = gpu_neg(&a, &dev).expect("gpu_neg 1");
        let neg2 = gpu_neg(&neg1, &dev).expect("gpu_neg 2");
        assert_buf_eq(&neg2, &dev, &a_data);
    }

    #[test]
    fn relu_basic() {
        let a_data = vec![-3.0f32, -1.0, 0.0, 1.0, 3.0];
        let expected = vec![0.0f32, 0.0, 0.0, 1.0, 3.0];
        let (dev, a) = setup(&a_data);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn relu_all_negative() {
        let a_data = vec![-5.0f32, -0.1, -100.0];
        let expected = vec![0.0f32, 0.0, 0.0];
        let (dev, a) = setup(&a_data);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &expected);
    }

    // Positive inputs must pass through ReLU unchanged.
    #[test]
    fn relu_all_positive() {
        let a_data = vec![0.1f32, 1.0, 100.0];
        let (dev, a) = setup(&a_data);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &a_data);
    }

    #[test]
    fn relu_empty() {
        let (dev, a) = setup(&[]);
        let out = gpu_relu(&a, &dev).expect("gpu_relu empty");
        assert_eq!(out.len(), 0);
    }
}