#![cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits};
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::module_cache::get_or_compile;
use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64};
/// Number of threads per block used by every one-dimensional launch in this file.
const BLOCK_SIZE: u32 = 256;

/// Builds a one-dimensional launch configuration with one thread per element.
///
/// The grid holds `ceil(n / BLOCK_SIZE)` blocks of `BLOCK_SIZE` threads and is
/// never empty: `n == 0` still yields a single block. No shared memory is
/// requested.
///
/// A block count that does not fit in `u32` is clamped to `u32::MAX` instead
/// of being truncated, so an oversized request cannot silently shrink into a
/// small grid that covers only part of the data.
fn launch_1d(n: usize) -> LaunchConfig {
    let per_block = BLOCK_SIZE as usize;
    // Ceiling division carried out in `usize`: a narrowing cast of `n` would
    // wrap for `n >= 2^32`, and a saturating add of `BLOCK_SIZE - 1` would
    // lose the final partial block when `n` is close to `u32::MAX`.
    let blocks = n / per_block + usize::from(n % per_block != 0);
    let grid = blocks.min(u32::MAX as usize) as u32;
    LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// PTX source for `diag_embed_f32_kernel`, which scatters a vector of `f32`
/// values onto one diagonal of a row-major `size x size` matrix.
///
/// Kernel parameters, in order: `in_ptr` and `out_ptr` (64-bit addresses),
/// `n` and `size` (`u32`), and `k` (`s32`).
///
/// Each thread forms a global index `i = ctaid.x * ntid.x + tid.x` and exits
/// when `i >= n`. Otherwise it loads `in[i]` and stores it to
/// `out[r * size + c]`, where `(r, c)` is `(i, i + k)` for `k >= 0` and
/// `(i - k, i)` for `k < 0`. No other output element is written.
///
/// NOTE: the linear index is formed with the 32-bit `mad.lo.u32` and the
/// kernel performs no bounds check on it, so the launcher must guarantee that
/// `r * size + c` fits in `u32` and lies inside the output buffer.
const DIAG_EMBED_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry diag_embed_f32_kernel(
.param .u64 in_ptr, .param .u64 out_ptr,
.param .u32 n, .param .u32 size, .param .s32 k
) {
.reg .u32 %gtid, %bid, %bdim, %tdx, %n, %size, %r, %c, %lin;
.reg .s32 %k_r, %i_s, %absk;
.reg .u64 %in, %out, %off, %addr;
.reg .f32 %v;
.reg .pred %p, %kneg;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n, [n];
ld.param.u32 %size, [size];
ld.param.s32 %k_r, [k];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tdx, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %tdx;
setp.ge.u32 %p, %gtid, %n;
@%p bra DONE;
cvt.s32.u32 %i_s, %gtid;
setp.lt.s32 %kneg, %k_r, 0;
@%kneg bra KNEG;
// k >= 0: r = i, c = i + k
mov.u32 %r, %gtid;
add.s32 %i_s, %i_s, %k_r;
cvt.u32.s32 %c, %i_s;
bra COMPUTE;
KNEG:
// k < 0: r = i + (-k), c = i
sub.s32 %absk, 0, %k_r;
add.s32 %i_s, %i_s, %absk;
cvt.u32.s32 %r, %i_s;
mov.u32 %c, %gtid;
COMPUTE:
mad.lo.u32 %lin, %r, %size, %c;
// load in[gtid]
cvt.u64.u32 %off, %gtid;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %v, [%addr];
// store out[lin]
cvt.u64.u32 %off, %lin;
shl.b64 %off, %off, 2;
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %v;
DONE:
ret;
}
";
/// PTX source for `diag_embed_f64_kernel`, which scatters a vector of `f64`
/// values onto one diagonal of a row-major `size x size` matrix.
///
/// Kernel parameters, in order: `in_ptr` and `out_ptr` (64-bit addresses),
/// `n` and `size` (`u32`), and `k` (`s32`).
///
/// Each thread forms a global index `i = ctaid.x * ntid.x + tid.x` and exits
/// when `i >= n`. Otherwise it loads `in[i]` and stores it to
/// `out[r * size + c]`, where `(r, c)` is `(i, i + k)` for `k >= 0` and
/// `(i - k, i)` for `k < 0`. Element offsets are shifted left by 3, i.e.
/// scaled by the 8-byte element size. No other output element is written.
///
/// NOTE: the linear index is formed with the 32-bit `mad.lo.u32` and the
/// kernel performs no bounds check on it, so the launcher must guarantee that
/// `r * size + c` fits in `u32` and lies inside the output buffer.
const DIAG_EMBED_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry diag_embed_f64_kernel(
.param .u64 in_ptr, .param .u64 out_ptr,
.param .u32 n, .param .u32 size, .param .s32 k
) {
.reg .u32 %gtid, %bid, %bdim, %tdx, %n, %size, %r, %c, %lin;
.reg .s32 %k_r, %i_s, %absk;
.reg .u64 %in, %out, %off, %addr;
.reg .f64 %v;
.reg .pred %p, %kneg;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n, [n];
ld.param.u32 %size, [size];
ld.param.s32 %k_r, [k];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tdx, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %tdx;
setp.ge.u32 %p, %gtid, %n;
@%p bra DONE;
cvt.s32.u32 %i_s, %gtid;
setp.lt.s32 %kneg, %k_r, 0;
@%kneg bra KNEG;
mov.u32 %r, %gtid;
add.s32 %i_s, %i_s, %k_r;
cvt.u32.s32 %c, %i_s;
bra COMPUTE;
KNEG:
sub.s32 %absk, 0, %k_r;
add.s32 %i_s, %i_s, %absk;
cvt.u32.s32 %r, %i_s;
mov.u32 %c, %gtid;
COMPUTE:
mad.lo.u32 %lin, %r, %size, %c;
cvt.u64.u32 %off, %gtid;
shl.b64 %off, %off, 3;
add.u64 %addr, %in, %off;
ld.global.f64 %v, [%addr];
cvt.u64.u32 %off, %lin;
shl.b64 %off, %off, 3;
add.u64 %addr, %out, %off;
st.global.f64 [%addr], %v;
DONE:
ret;
}
";
/// PTX source for `diag_extract_f32_kernel`, which gathers one diagonal of a
/// row-major matrix of `f32` values into a contiguous vector.
///
/// Kernel parameters, in order: `in_ptr` and `out_ptr` (64-bit addresses),
/// then `diag_len`, `cols`, `start_r` and `start_c` (all `u32`).
///
/// Each thread forms a global index `i = ctaid.x * ntid.x + tid.x` and exits
/// when `i >= diag_len`. Otherwise it loads
/// `in[(start_r + i) * cols + (start_c + i)]` and stores it to `out[i]`.
///
/// NOTE: the input index is formed with 32-bit `add.u32` / `mad.lo.u32` and
/// the kernel performs no bounds check on it, so the launcher must guarantee
/// that the index fits in `u32` and lies inside the input buffer.
const DIAG_EXTRACT_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry diag_extract_f32_kernel(
.param .u64 in_ptr, .param .u64 out_ptr,
.param .u32 diag_len, .param .u32 cols, .param .u32 start_r, .param .u32 start_c
) {
.reg .u32 %gtid, %bid, %bdim, %tdx, %dl, %cols, %sr, %sc, %row, %col, %lin;
.reg .u64 %in, %out, %off, %addr;
.reg .f32 %v;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %dl, [diag_len];
ld.param.u32 %cols, [cols];
ld.param.u32 %sr, [start_r];
ld.param.u32 %sc, [start_c];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tdx, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %tdx;
setp.ge.u32 %p, %gtid, %dl;
@%p bra DONE;
add.u32 %row, %sr, %gtid;
add.u32 %col, %sc, %gtid;
mad.lo.u32 %lin, %row, %cols, %col;
cvt.u64.u32 %off, %lin;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %v, [%addr];
cvt.u64.u32 %off, %gtid;
shl.b64 %off, %off, 2;
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %v;
DONE:
ret;
}
";
/// PTX source for `diag_extract_f64_kernel`, which gathers one diagonal of a
/// row-major matrix of `f64` values into a contiguous vector.
///
/// Kernel parameters, in order: `in_ptr` and `out_ptr` (64-bit addresses),
/// then `diag_len`, `cols`, `start_r` and `start_c` (all `u32`).
///
/// Each thread forms a global index `i = ctaid.x * ntid.x + tid.x` and exits
/// when `i >= diag_len`. Otherwise it loads
/// `in[(start_r + i) * cols + (start_c + i)]` and stores it to `out[i]`.
/// Element offsets are shifted left by 3, i.e. scaled by the 8-byte element
/// size.
///
/// NOTE: the input index is formed with 32-bit `add.u32` / `mad.lo.u32` and
/// the kernel performs no bounds check on it, so the launcher must guarantee
/// that the index fits in `u32` and lies inside the input buffer.
const DIAG_EXTRACT_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry diag_extract_f64_kernel(
.param .u64 in_ptr, .param .u64 out_ptr,
.param .u32 diag_len, .param .u32 cols, .param .u32 start_r, .param .u32 start_c
) {
.reg .u32 %gtid, %bid, %bdim, %tdx, %dl, %cols, %sr, %sc, %row, %col, %lin;
.reg .u64 %in, %out, %off, %addr;
.reg .f64 %v;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %dl, [diag_len];
ld.param.u32 %cols, [cols];
ld.param.u32 %sr, [start_r];
ld.param.u32 %sc, [start_c];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %tdx, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %tdx;
setp.ge.u32 %p, %gtid, %dl;
@%p bra DONE;
add.u32 %row, %sr, %gtid;
add.u32 %col, %sc, %gtid;
mad.lo.u32 %lin, %row, %cols, %col;
cvt.u64.u32 %off, %lin;
shl.b64 %off, %off, 3;
add.u64 %addr, %in, %off;
ld.global.f64 %v, [%addr];
cvt.u64.u32 %off, %gtid;
shl.b64 %off, %off, 3;
add.u64 %addr, %out, %off;
st.global.f64 [%addr], %v;
DONE:
ret;
}
";
/// Launches a `diag_embed` kernel that scatters the first `n` elements of
/// `in_slice` onto the `k`-th diagonal of a row-major `size x size` matrix
/// stored in `out_slice`.
///
/// Element `i` is written to `(i, i + k)` when `k >= 0` and to `(i + |k|, i)`
/// when `k < 0`. The embed kernels in this file write nothing else, so the
/// caller supplies an output buffer whose other elements already hold the
/// desired fill value. `n == 0` is a no-op that returns `Ok(())` without
/// validating or launching anything.
///
/// `ptx` / `kernel_name` select the element-type-specific kernel; they are
/// resolved through `get_or_compile`.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] when
/// - `in_slice` holds fewer than `n` elements,
/// - `size * size` overflows `usize` or `out_slice` holds fewer than
///   `size * size` elements,
/// - `n + |k|` exceeds `size`, i.e. the diagonal would leave the matrix, or
/// - `size * size` exceeds `u32::MAX`, because the kernel computes element
///   indices with 32-bit arithmetic.
///
/// Returns [`GpuError::PtxCompileFailed`] when `get_or_compile` fails, and
/// propagates any error reported by the launch itself.
#[allow(clippy::too_many_arguments)] // mirrors `launch_diag_extract`; private helper
fn launch_diag_embed<V: DeviceRepr + ValidAsZeroBits>(
    in_slice: &CudaSlice<V>,
    out_slice: &mut CudaSlice<V>,
    n: usize,
    size: usize,
    k: i64,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<()> {
    /// Largest element count the kernel's 32-bit index arithmetic can address.
    const INDEX_LIMIT: usize = u32::MAX as usize;

    if n == 0 {
        return Ok(());
    }
    if in_slice.len() < n {
        return Err(GpuError::LengthMismatch {
            a: in_slice.len(),
            b: n,
        });
    }
    let total = size
        .checked_mul(size)
        .ok_or(GpuError::LengthMismatch { a: size, b: size })?;
    if out_slice.len() < total {
        return Err(GpuError::LengthMismatch {
            a: out_slice.len(),
            b: total,
        });
    }

    // The largest row/column touched is `n - 1 + |k|`, so the diagonal stays
    // inside the matrix only when `n + |k| <= size`. The kernel does not
    // bounds-check its store, so this has to be enforced here.
    let offset = k.unsigned_abs().min(usize::MAX as u64) as usize;
    let needed = n.saturating_add(offset);
    if needed > size {
        return Err(GpuError::LengthMismatch { a: size, b: needed });
    }
    // The kernel forms `r * size + c` in a 32-bit register; larger matrices
    // would wrap around and scatter values to the wrong elements.
    if total > INDEX_LIMIT {
        return Err(GpuError::LengthMismatch {
            a: total,
            b: INDEX_LIMIT,
        });
    }

    let stream = device.stream();
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let cfg = launch_1d(n);

    // Lossless narrowing: `1 <= n <= size <= size * size <= u32::MAX`, and
    // `|k| <= size <= 65_535`, which is well inside the `i32` range.
    let n_u = n as u32;
    let size_u = size as u32;
    let k_i32 = k as i32;

    // SAFETY: the argument list matches the kernel signature
    // `(in_ptr, out_ptr, n: u32, size: u32, k: s32)`. The kernel reads
    // `in[0..n]` and writes only indices below `size * size`; both ranges were
    // checked against the slice lengths above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(in_slice)
            .arg(out_slice)
            .arg(&n_u)
            .arg(&size_u)
            .arg(&k_i32)
            .launch(cfg)?;
    }
    Ok(())
}
/// Launches a `diag_extract` kernel that copies `diag_len` elements of one
/// diagonal of a row-major `rows x cols` matrix into the front of `out_slice`.
///
/// Output element `i` receives the matrix element at
/// `(start_r + i, start_c + i)`. `diag_len == 0` is a no-op that returns
/// `Ok(())` without validating or launching anything.
///
/// `ptx` / `kernel_name` select the element-type-specific kernel; they are
/// resolved through `get_or_compile`.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] when
/// - `rows * cols` overflows `usize` or `in_slice` holds fewer than
///   `rows * cols` elements,
/// - `out_slice` holds fewer than `diag_len` elements,
/// - `start_r + diag_len` exceeds `rows` or `start_c + diag_len` exceeds
///   `cols`, i.e. the requested diagonal leaves the matrix, or
/// - `rows * cols` exceeds `u32::MAX`, because the kernel computes element
///   indices with 32-bit arithmetic.
///
/// Returns [`GpuError::PtxCompileFailed`] when `get_or_compile` fails, and
/// propagates any error reported by the launch itself.
#[allow(clippy::too_many_arguments)] // private helper; every argument feeds the kernel
fn launch_diag_extract<V: DeviceRepr + ValidAsZeroBits>(
    in_slice: &CudaSlice<V>,
    out_slice: &mut CudaSlice<V>,
    rows: usize,
    cols: usize,
    diag_len: usize,
    start_r: usize,
    start_c: usize,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<()> {
    /// Largest element count the kernel's 32-bit index arithmetic can address.
    const INDEX_LIMIT: usize = u32::MAX as usize;

    if diag_len == 0 {
        return Ok(());
    }
    let in_total = rows
        .checked_mul(cols)
        .ok_or(GpuError::LengthMismatch { a: rows, b: cols })?;
    if in_slice.len() < in_total {
        return Err(GpuError::LengthMismatch {
            a: in_slice.len(),
            b: in_total,
        });
    }
    if out_slice.len() < diag_len {
        return Err(GpuError::LengthMismatch {
            a: out_slice.len(),
            b: diag_len,
        });
    }

    // The kernel does not bounds-check its load, so the diagonal window must
    // be proven to lie inside the matrix before launching.
    let inside = |start: usize, extent: usize| {
        matches!(start.checked_add(diag_len), Some(end) if end <= extent)
    };
    if !inside(start_r, rows) {
        return Err(GpuError::LengthMismatch {
            a: rows,
            b: start_r.saturating_add(diag_len),
        });
    }
    if !inside(start_c, cols) {
        return Err(GpuError::LengthMismatch {
            a: cols,
            b: start_c.saturating_add(diag_len),
        });
    }
    // The kernel forms `row * cols + col` in a 32-bit register; larger
    // matrices would wrap around and read the wrong elements.
    if in_total > INDEX_LIMIT {
        return Err(GpuError::LengthMismatch {
            a: in_total,
            b: INDEX_LIMIT,
        });
    }

    let stream = device.stream();
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let cfg = launch_1d(diag_len);

    // Lossless narrowing: `diag_len >= 1` together with the window checks
    // gives `rows, cols >= 1`, hence `cols <= rows * cols <= u32::MAX`, and
    // `diag_len`, `start_r`, `start_c` are each bounded by `rows` or `cols`.
    let dl_u = diag_len as u32;
    let cols_u = cols as u32;
    let sr_u = start_r as u32;
    let sc_u = start_c as u32;

    // SAFETY: the argument list matches the kernel signature
    // `(in_ptr, out_ptr, diag_len, cols, start_r, start_c)` (all `u32` after
    // the two pointers). The kernel reads only indices below `rows * cols`
    // and writes `out[0..diag_len]`; both ranges were checked against the
    // slice lengths above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(in_slice)
            .arg(out_slice)
            .arg(&dl_u)
            .arg(&cols_u)
            .arg(&sr_u)
            .arg(&sc_u)
            .launch(cfg)?;
    }
    Ok(())
}
pub fn gpu_diag_embed_f32(
input: &CudaBuffer<f32>,
n: usize,
k: i64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let size = n + k.unsigned_abs() as usize;
let mut out = alloc_zeros_f32(size * size, device)?;
launch_diag_embed(
input.inner(),
out.inner_mut(),
n,
size,
k,
device,
DIAG_EMBED_F32_PTX,
"diag_embed_f32_kernel",
)?;
Ok(out)
}
pub fn gpu_diag_embed_f64(
input: &CudaBuffer<f64>,
n: usize,
k: i64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let size = n + k.unsigned_abs() as usize;
let mut out = alloc_zeros_f64(size * size, device)?;
launch_diag_embed(
input.inner(),
out.inner_mut(),
n,
size,
k,
device,
DIAG_EMBED_F64_PTX,
"diag_embed_f64_kernel",
)?;
Ok(out)
}
/// Copies the `k`-th diagonal of a row-major `rows x cols` `f32` matrix into
/// a new GPU buffer.
///
/// `k >= 0` starts the diagonal at `(0, k)`, `k < 0` at `(|k|, 0)`. The
/// diagonal length is the smaller of the rows and columns remaining after
/// that start point. The returned buffer is allocated with at least one
/// element, so a diagonal that falls outside the matrix yields a
/// one-element buffer that the kernel never writes.
///
/// # Errors
///
/// Forwards any error from allocation or from `launch_diag_extract`.
pub fn gpu_diag_extract_f32(
    input: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    k: i64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // |k| is the distance from the main diagonal; the sign of `k` decides
    // whether it shifts the starting column or the starting row.
    let shift = k.unsigned_abs() as usize;
    let (start_r, start_c) = if k < 0 {
        (shift, 0usize)
    } else {
        (0usize, shift)
    };

    let rows_left = rows.saturating_sub(start_r);
    let cols_left = cols.saturating_sub(start_c);
    let diag_len = rows_left.min(cols_left);

    // Always request at least one element from the allocator.
    let alloc_len = diag_len.max(1);
    let mut out = alloc_zeros_f32(alloc_len, device)?;
    launch_diag_extract(
        input.inner(),
        out.inner_mut(),
        rows,
        cols,
        diag_len,
        start_r,
        start_c,
        device,
        DIAG_EXTRACT_F32_PTX,
        "diag_extract_f32_kernel",
    )?;
    Ok(out)
}
/// Copies the `k`-th diagonal of a row-major `rows x cols` `f64` matrix into
/// a new GPU buffer.
///
/// `k >= 0` starts the diagonal at `(0, k)`, `k < 0` at `(|k|, 0)`. The
/// diagonal length is the smaller of the rows and columns remaining after
/// that start point. The returned buffer is allocated with at least one
/// element, so a diagonal that falls outside the matrix yields a
/// one-element buffer that the kernel never writes.
///
/// # Errors
///
/// Forwards any error from allocation or from `launch_diag_extract`.
pub fn gpu_diag_extract_f64(
    input: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    k: i64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // |k| is the distance from the main diagonal; the sign of `k` decides
    // whether it shifts the starting column or the starting row.
    let shift = k.unsigned_abs() as usize;
    let (start_r, start_c) = if k < 0 {
        (shift, 0usize)
    } else {
        (0usize, shift)
    };

    let rows_left = rows.saturating_sub(start_r);
    let cols_left = cols.saturating_sub(start_c);
    let diag_len = rows_left.min(cols_left);

    // Always request at least one element from the allocator.
    let alloc_len = diag_len.max(1);
    let mut out = alloc_zeros_f64(alloc_len, device)?;
    launch_diag_extract(
        input.inner(),
        out.inner_mut(),
        rows,
        cols,
        diag_len,
        start_r,
        start_c,
        device,
        DIAG_EXTRACT_F64_PTX,
        "diag_extract_f64_kernel",
    )?;
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transfer::{cpu_to_gpu, gpu_to_cpu};

    /// Opens CUDA device 0; the tests cannot run without it.
    fn dev() -> GpuDevice {
        GpuDevice::new(0).expect("cuda device")
    }

    /// Host-side reference for the embed kernels: a row-major square matrix
    /// of side `data.len() + |k|` with `data` on the `k`-th diagonal.
    fn cpu_embed(data: &[f32], k: i64) -> Vec<f32> {
        let shift = k.unsigned_abs() as usize;
        let side = data.len() + shift;
        let mut dense = vec![0.0f32; side * side];
        for (i, &value) in data.iter().enumerate() {
            let (row, col) = if k < 0 { (i + shift, i) } else { (i, i + shift) };
            dense[row * side + col] = value;
        }
        dense
    }

    /// Row-major 3x3 matrix holding 1..=9, shared by the `f32` extract tests.
    fn counting_3x3() -> Vec<f32> {
        (1..=9).map(|v| v as f32).collect()
    }

    #[test]
    fn diag_embed_f32_main() {
        let gpu = dev();
        let values = vec![1.0f32, 2.0, 3.0];
        let reference = cpu_embed(&values, 0);
        let uploaded = cpu_to_gpu(&values, &gpu).unwrap();
        let embedded = gpu_diag_embed_f32(&uploaded, 3, 0, &gpu).unwrap();
        let host = gpu_to_cpu(&embedded, &gpu).unwrap();
        assert_eq!(&host[..9], &reference[..]);
        assert_eq!(&host[..9], &[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
    }

    #[test]
    fn diag_embed_f32_positive_offset() {
        let gpu = dev();
        let values = vec![1.0f32, 2.0];
        let reference = cpu_embed(&values, 1);
        let uploaded = cpu_to_gpu(&values, &gpu).unwrap();
        let embedded = gpu_diag_embed_f32(&uploaded, 2, 1, &gpu).unwrap();
        let host = gpu_to_cpu(&embedded, &gpu).unwrap();
        assert_eq!(&host[..9], &reference[..]);
        assert_eq!(&host[..9], &[0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn diag_embed_f32_negative_offset() {
        let gpu = dev();
        let values = vec![1.0f32, 2.0];
        let reference = cpu_embed(&values, -1);
        let uploaded = cpu_to_gpu(&values, &gpu).unwrap();
        let embedded = gpu_diag_embed_f32(&uploaded, 2, -1, &gpu).unwrap();
        let host = gpu_to_cpu(&embedded, &gpu).unwrap();
        assert_eq!(&host[..9], &reference[..]);
        assert_eq!(&host[..9], &[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0]);
    }

    #[test]
    fn diag_extract_f32_main() {
        let gpu = dev();
        let matrix = counting_3x3();
        let uploaded = cpu_to_gpu(&matrix, &gpu).unwrap();
        let diagonal = gpu_diag_extract_f32(&uploaded, 3, 3, 0, &gpu).unwrap();
        let host = gpu_to_cpu(&diagonal, &gpu).unwrap();
        assert_eq!(&host[..3], &[1.0, 5.0, 9.0]);
    }

    #[test]
    fn diag_extract_f32_positive_offset() {
        let gpu = dev();
        let matrix = counting_3x3();
        let uploaded = cpu_to_gpu(&matrix, &gpu).unwrap();
        let diagonal = gpu_diag_extract_f32(&uploaded, 3, 3, 1, &gpu).unwrap();
        let host = gpu_to_cpu(&diagonal, &gpu).unwrap();
        assert_eq!(&host[..2], &[2.0, 6.0]);
    }

    #[test]
    fn diag_extract_f32_negative_offset() {
        let gpu = dev();
        let matrix = counting_3x3();
        let uploaded = cpu_to_gpu(&matrix, &gpu).unwrap();
        let diagonal = gpu_diag_extract_f32(&uploaded, 3, 3, -1, &gpu).unwrap();
        let host = gpu_to_cpu(&diagonal, &gpu).unwrap();
        assert_eq!(&host[..2], &[4.0, 8.0]);
    }

    #[test]
    fn diag_embed_f64_main() {
        let gpu = dev();
        let values = vec![1.0f64, 2.0, 3.0];
        let uploaded = cpu_to_gpu(&values, &gpu).unwrap();
        let embedded = gpu_diag_embed_f64(&uploaded, 3, 0, &gpu).unwrap();
        let host = gpu_to_cpu(&embedded, &gpu).unwrap();
        assert_eq!(&host[..9], &[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
    }

    #[test]
    fn diag_extract_f64_main() {
        let gpu = dev();
        let matrix: Vec<f64> = (1..=9).map(|v| v as f64).collect();
        let uploaded = cpu_to_gpu(&matrix, &gpu).unwrap();
        let diagonal = gpu_diag_extract_f64(&uploaded, 3, 3, 0, &gpu).unwrap();
        let host = gpu_to_cpu(&diagonal, &gpu).unwrap();
        assert_eq!(&host[..3], &[1.0, 5.0, 9.0]);
    }
}