#![allow(non_snake_case)]
#![allow(non_camel_case_types)]
use std::{
ffi::{c_char, c_int, c_uint, c_void},
hash::BuildHasherDefault,
ptr,
sync::mpsc::{Receiver, Sender, channel},
};
use libloading::Library;
use nanoserde::DeJson;
use crate::{
DType, Map,
dtype::Constant,
error::{BackendError, ErrorStatus},
graph::{BOp, UOp},
kernel::{IDX_T, Kernel, Op, OpId, Scope},
shape::Dim,
slab::Slab,
};
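// Unwraps a `Result` inside the worker loop: on error, forward the error over the
// reply channel and move on to the next command.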
macro_rules! send_or_continue {
($expr:expr, $tx:expr) => {
match $expr {
Ok(v) => v,
Err(e) => {
let _ = $tx.send(Err(e));
continue;
}
}
};
}
use super::{Device, DeviceId, DeviceInfo, DeviceProgramId, Event, MemoryPool, PoolBufferId, PoolId};
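// Backend configuration, deserialized from the user's JSON config. A hypothetical
// example: `{ "device_ids": [0, 1] }` restricts the backend to those two devices,
// an absent/null field uses every visible device, and `[]` disables CUDA entirely.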
#[derive(Debug, Default, DeJson)]
pub struct CUDAConfig {
device_ids: Option<Vec<i32>>,
}
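// Every driver call runs on a per-device worker thread spawned in
// `initialize_device`; the pool and device structs below are thin handles holding
// only the `Sender` side of that thread's command channel.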
#[derive(Debug)]
pub struct CUDAMemoryPool {
tx: Sender<CUDACommand>,
free_bytes: Dim,
}
#[derive(Debug)]
pub(super) struct CUDABuffer {
ptr: u64,
bytes: Dim,
}
#[derive(Debug)]
pub struct CUDADevice {
tx: Sender<CUDACommand>,
device: CUdevice,
memory_pool_id: PoolId,
dev_info: DeviceInfo,
compute_capability: [c_int; 2],
}
#[derive(Debug)]
pub(super) struct CUDAProgram {
module: CUmodule,
function: CUfunction,
gws: Vec<Dim>,
lws: Vec<Dim>,
}
#[derive(Debug)]
pub(super) struct CUDAStream {
stream: CUstream,
load: usize,
}
#[derive(Debug, Clone)]
pub struct CUDAEvent {
event: CUevent,
}
unsafe impl Send for CUDAEvent {}
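// Commands marshalled to the worker thread. Variants that produce a result carry a
// `reply` channel that the caller blocks on.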
enum CUDACommand {
Allocate {
bytes: usize,
reply: Sender<Result<(PoolBufferId, Event), BackendError>>,
},
Deallocate {
buffer_id: PoolBufferId,
events: Vec<Event>,
},
HostToPool {
src: *const u8,
bytes: usize,
dst: PoolBufferId,
event_wait_list: Vec<Event>,
reply: Sender<Result<Event, BackendError>>,
},
PoolToHost {
src: PoolBufferId,
dst: *mut u8,
bytes: usize,
event_wait_list: Vec<Event>,
reply: Sender<Result<(), BackendError>>,
},
Compile {
gws: Vec<Dim>,
lws: Vec<Dim>,
name: Box<str>,
ptx: Vec<u8>,
reply: Sender<Result<DeviceProgramId, BackendError>>,
},
Launch {
program_id: DeviceProgramId,
args: Vec<PoolBufferId>,
event_wait_list: Vec<Event>,
reply: Sender<Result<Event, BackendError>>,
},
SyncEvents {
events: Vec<Event>,
reply: Sender<Result<(), BackendError>>,
},
ReleaseProgram {
program_id: DeviceProgramId,
},
ReleaseEvents {
events: Vec<Event>,
},
}
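// SAFETY: `HostToPool`/`PoolToHost` carry raw host pointers across threads. This is
// sound only because callers keep the referenced memory alive until the transfer is
// done: `pool_to_host` synchronizes before replying, and `host_to_pool` returns an
// event that must complete before the source buffer is reused or freed.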
unsafe impl Send for CUDACommand {}
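// Loads libcuda.so, enumerates the requested devices, and spawns one worker thread
// per device. Each worker owns its CUDA context, streams, buffers, and compiled
// programs, and serves `CUDACommand`s until the channel closes.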
pub(super) fn initialize_device(
config: &CUDAConfig,
memory_pools: &mut Slab<PoolId, MemoryPool>,
devices: &mut Slab<DeviceId, Device>,
debug_dev: bool,
) -> Result<(), BackendError> {
if let Some(device_ids) = &config.device_ids
&& device_ids.is_empty()
{
if debug_dev {
println!("CUDA won't be used, as it was configured out");
}
return Ok(());
}
let cuda_paths = [
"/lib/x86_64-linux-gnu/libcuda.so",
"/lib64/x86_64-linux-gnu/libcuda.so",
"/lib/libcuda.so",
"/lib64/libcuda.so",
"/usr/lib/libcuda.so",
"/usr/lib64/libcuda.so",
];
let cuda = cuda_paths.into_iter().find_map(|path| unsafe { Library::new(path) }.ok());
let Some(cuda) = cuda else {
if debug_dev {
println!("libcuda.so not found");
}
return Err(BackendError { status: ErrorStatus::DyLibNotFound, context: "CUDA libcuda.so not found.".into() });
};
let cuInit: unsafe extern "C" fn(c_uint) -> CUDAStatus = *unsafe { cuda.get(b"cuInit\0") }?;
let cuDriverGetVersion: unsafe extern "C" fn(*mut c_int) -> CUDAStatus = *unsafe { cuda.get(b"cuDriverGetVersion\0") }?;
let cuDeviceGetCount: unsafe extern "C" fn(*mut c_int) -> CUDAStatus = *unsafe { cuda.get(b"cuDeviceGetCount\0") }?;
let cuDeviceGet: unsafe extern "C" fn(*mut CUdevice, c_int) -> CUDAStatus = *unsafe { cuda.get(b"cuDeviceGet\0") }?;
let cuDeviceGetName: unsafe extern "C" fn(*mut c_char, c_int, CUdevice) -> CUDAStatus =
*unsafe { cuda.get(b"cuDeviceGetName\0") }?;
let cuDeviceComputeCapability: unsafe extern "C" fn(*mut c_int, *mut c_int, CUdevice) -> CUDAStatus =
*unsafe { cuda.get(b"cuDeviceComputeCapability\0") }?;
let cuDeviceTotalMem: unsafe extern "C" fn(*mut usize, CUdevice) -> CUDAStatus = *unsafe { cuda.get(b"cuDeviceTotalMem\0") }?;
let cuDeviceGetAttribute: unsafe extern "C" fn(*mut c_int, CUdevice_attribute, CUdevice) -> CUDAStatus =
*unsafe { cuda.get(b"cuDeviceGetAttribute\0") }?;
let cuCtxCreate: unsafe extern "C" fn(*mut CUcontext, c_uint, CUdevice) -> CUDAStatus =
*unsafe { cuda.get(b"cuCtxCreate\0") }?;
let cuMemAlloc: unsafe extern "C" fn(*mut CUdeviceptr, usize) -> CUDAStatus = *unsafe { cuda.get(b"cuMemAlloc\0") }?;
let cuMemFree: unsafe extern "C" fn(CUdeviceptr) -> CUDAStatus = *unsafe { cuda.get(b"cuMemFree\0") }?;
let cuMemcpyHtoDAsync: unsafe extern "C" fn(CUdeviceptr, *const c_void, usize, CUstream) -> CUDAStatus =
*unsafe { cuda.get(b"cuMemcpyHtoDAsync\0") }?;
let cuMemcpyDtoHAsync: unsafe extern "C" fn(*mut c_void, CUdeviceptr, usize, CUstream) -> CUDAStatus =
*unsafe { cuda.get(b"cuMemcpyDtoHAsync\0") }?;
let cuModuleLoadDataEx: unsafe extern "C" fn(
*mut CUmodule,
*const c_void,
c_uint,
*mut CUjit_option,
*mut *mut c_void,
) -> CUDAStatus = *unsafe { cuda.get(b"cuModuleLoadDataEx\0") }?;
let cuModuleGetFunction: unsafe extern "C" fn(*mut CUfunction, CUmodule, *const c_char) -> CUDAStatus =
*unsafe { cuda.get(b"cuModuleGetFunction\0") }?;
let cuLaunchKernel: unsafe extern "C" fn(
CUfunction,
c_uint,
c_uint,
c_uint,
c_uint,
c_uint,
c_uint,
c_uint,
CUstream,
*mut *mut c_void,
*mut *mut c_void,
) -> CUDAStatus = *unsafe { cuda.get(b"cuLaunchKernel\0") }?;
let cuStreamCreate: unsafe extern "C" fn(*mut CUstream, c_uint) -> CUDAStatus = *unsafe { cuda.get(b"cuStreamCreate\0") }?;
let cuStreamSynchronize: unsafe extern "C" fn(CUstream) -> CUDAStatus = *unsafe { cuda.get(b"cuStreamSynchronize\0") }?;
let cuStreamWaitEvent: unsafe extern "C" fn(CUstream, CUevent, c_uint) -> CUDAStatus =
*unsafe { cuda.get(b"cuStreamWaitEvent\0") }?;
let cuModuleUnload: unsafe extern "C" fn(CUmodule) -> CUDAStatus = *unsafe { cuda.get(b"cuModuleUnload\0") }?;
let cuEventCreate: unsafe extern "C" fn(*mut CUevent, c_uint) -> CUDAStatus = *unsafe { cuda.get(b"cuEventCreate\0") }?;
let cuEventRecord: unsafe extern "C" fn(CUevent, CUstream) -> CUDAStatus = *unsafe { cuda.get(b"cuEventRecord\0") }?;
let cuEventSynchronize: unsafe extern "C" fn(CUevent) -> CUDAStatus = *unsafe { cuda.get(b"cuEventSynchronize\0") }?;
let cuEventDestroy: unsafe extern "C" fn(CUevent) -> CUDAStatus = *unsafe { cuda.get(b"cuEventDestroy\0") }?;
if let Err(err) = unsafe { cuInit(0) }.check(ErrorStatus::Initialization) {
if debug_dev {
println!("CUDA requested, but cuInit failed. {err:?}");
}
return Err(err);
}
let mut driver_version = 0;
unsafe { cuDriverGetVersion(&raw mut driver_version) }.check(ErrorStatus::DeviceQuery)?;
let mut num_devices = 0;
unsafe { cuDeviceGetCount(&raw mut num_devices) }.check(ErrorStatus::DeviceQuery)?;
if num_devices == 0 {
return Err(BackendError { status: ErrorStatus::DeviceEnumeration, context: "No available CUDA device.".into() });
}
let device_ids: Vec<i32> = (0..num_devices)
.filter(|id| config.device_ids.as_ref().is_none_or(|ids| ids.contains(id)))
.collect();
if debug_dev && !device_ids.is_empty() {
println!(
"Using CUDA driver version {}.{} on devices:",
driver_version / 1000,
(driver_version % 1000) / 10
);
}
for dev_id in device_ids {
let mut device = 0;
if let Err(err) = unsafe { cuDeviceGet(&raw mut device, dev_id) }.check(ErrorStatus::DeviceEnumeration) {
if debug_dev {
println!("Device with id {dev_id} requested, but could not be enumerated: {err}.");
}
continue;
}
let mut device_name = [0; 100];
let Ok(()) = unsafe { cuDeviceGetName(device_name.as_mut_ptr(), 100, device) }.check(ErrorStatus::DeviceQuery) else {
continue;
};
let mut major = 0;
let mut minor = 0;
let Ok(()) = unsafe { cuDeviceComputeCapability(&raw mut major, &raw mut minor, device) }.check(ErrorStatus::DeviceQuery)
else {
continue;
};
if debug_dev {
println!("{:?}, compute capability: {major}.{minor}", unsafe {
std::ffi::CStr::from_ptr(device_name.as_ptr())
});
}
let mut free_bytes = 0;
let Ok(()) = unsafe { cuDeviceTotalMem(&raw mut free_bytes, device) }.check(ErrorStatus::DeviceQuery) else {
continue;
};
let mut context: CUcontext = ptr::null_mut();
if let Err(e) = unsafe { cuCtxCreate(&raw mut context, 0, device) }.check(ErrorStatus::Initialization) {
if debug_dev {
println!("Device with id {dev_id} requested, but cuda context initialization failed. {e:?}");
}
continue;
}
let (tx, rx): (Sender<CUDACommand>, Receiver<CUDACommand>) = channel();
std::thread::spawn(move || {
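// `cuCtxCreate` makes the new context current on the calling thread, so the worker
// creates its own; the context created above only verified that the device works.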
let mut context: CUcontext = ptr::null_mut();
if let Err(e) = unsafe { cuCtxCreate(&raw mut context, 0, device) }.check(ErrorStatus::Initialization) {
if debug_dev {
println!("Cuda context initialization failed. {e:?}");
}
return;
}
let mut streams = Vec::new();
for _ in 0..8 {
let mut stream = ptr::null_mut();
if let Err(err) = unsafe { cuStreamCreate(&raw mut stream, 0) }.check(ErrorStatus::Initialization) {
if debug_dev {
println!("Device with id {dev_id} requested, but cuda stream initialization failed. {err:?}");
}
continue;
}
streams.push(CUDAStream { stream, load: 0 });
}
let mut buffers: Slab<PoolBufferId, CUDABuffer> = Slab::new();
let mut programs: Slab<DeviceProgramId, CUDAProgram> = Slab::new();
'work_thread_loop: while let Ok(cmd) = rx.recv() {
match cmd {
CUDACommand::Allocate { bytes, reply } => {
let stream = next_stream(&mut streams, cuStreamSynchronize);
let mut ptr: CUdeviceptr = 0;
let mut event = ptr::null_mut();
send_or_continue!(
unsafe { (cuEventCreate)(&raw mut event, 0x2) }.check(ErrorStatus::MemoryAllocation),
reply
);
debug_assert!(!stream.is_null());
send_or_continue!(
unsafe { (cuMemAlloc)(&raw mut ptr, bytes) }.check(ErrorStatus::MemoryAllocation),
reply
);
if ptr % 8 != 0 {
panic!("Memory is not 8-byte aligned!");
}
send_or_continue!(
unsafe { (cuEventRecord)(event, stream) }.check(ErrorStatus::MemoryAllocation),
reply
);
debug_assert!(free_bytes >= bytes);
free_bytes = free_bytes.saturating_sub(bytes);
let buffer_id = buffers.push(CUDABuffer { ptr, bytes });
let event = Event::CUDA(CUDAEvent { event });
let _ = reply.send(Ok((buffer_id, event)));
}
CUDACommand::Deallocate { buffer_id, mut events } => {
let stream = next_stream(&mut streams, cuStreamSynchronize);
while let Some(Event::CUDA(CUDAEvent { event })) = events.pop() {
if !event.is_null() {
_ = unsafe { (cuStreamWaitEvent)(stream, event, 0) }.check(ErrorStatus::MemoryDeallocation);
_ = unsafe { (cuEventDestroy)(event) }.check(ErrorStatus::MemoryDeallocation);
}
}
let buffer = &mut buffers[buffer_id];
_ = unsafe { (cuMemFree)(buffer.ptr) }.check(ErrorStatus::MemoryDeallocation);
free_bytes += buffer.bytes;
buffers.remove(buffer_id);
}
CUDACommand::HostToPool { src, bytes, dst, mut event_wait_list, reply } => {
let stream = next_stream(&mut streams, cuStreamSynchronize);
let dst = &buffers[dst];
while let Some(Event::CUDA(CUDAEvent { event })) = event_wait_list.pop() {
if !event.is_null() {
send_or_continue!(
unsafe { (cuStreamWaitEvent)(stream, event, 0) }.check(ErrorStatus::MemoryCopyH2P),
reply
);
}
}
let mut event = ptr::null_mut();
send_or_continue!(
unsafe { (cuEventCreate)(&raw mut event, 0x2) }.check(ErrorStatus::MemoryCopyH2P),
reply
);
debug_assert!(!stream.is_null());
send_or_continue!(
unsafe { (cuMemcpyHtoDAsync)(dst.ptr, src.cast(), bytes, stream) }.check(ErrorStatus::MemoryCopyH2P),
reply
);
send_or_continue!(
unsafe { (cuEventRecord)(event, stream) }.check(ErrorStatus::MemoryCopyH2P),
reply
);
_ = reply.send(Ok(Event::CUDA(CUDAEvent { event })));
}
CUDACommand::PoolToHost { src, dst, bytes, mut event_wait_list, reply } => {
let stream = next_stream(&mut streams, cuStreamSynchronize);
while let Some(Event::CUDA(CUDAEvent { event })) = event_wait_list.pop() {
if !event.is_null() {
send_or_continue!(
unsafe { (cuStreamWaitEvent)(stream, event, 0) }.check(ErrorStatus::MemoryCopyP2H),
reply
);
}
}
let src = &buffers[src];
let mut event = ptr::null_mut();
send_or_continue!(
unsafe { (cuEventCreate)(&raw mut event, 0x2) }.check(ErrorStatus::MemoryCopyP2H),
reply
);
send_or_continue!(
unsafe { (cuMemcpyDtoHAsync)(dst.cast(), src.ptr, bytes, stream) }.check(ErrorStatus::MemoryCopyP2H),
reply
);
send_or_continue!(
unsafe { (cuEventRecord)(event, stream) }.check(ErrorStatus::MemoryCopyP2H),
reply
);
send_or_continue!(
unsafe { (cuEventSynchronize)(event) }.check(ErrorStatus::MemoryCopyP2H),
reply
);
send_or_continue!(unsafe { (cuEventDestroy)(event) }.check(ErrorStatus::MemoryCopyP2H), reply);
_ = reply.send(Ok(()));
}
CUDACommand::Compile { gws, lws, name, ptx, reply } => {
let mut module = ptr::null_mut();
if let Err(err) = unsafe {
(cuModuleLoadDataEx)(&raw mut module, ptx.as_ptr().cast(), 0, ptr::null_mut(), ptr::null_mut())
}
.check(ErrorStatus::KernelCompilation)
{
if debug_dev {
println!("Failed to compile kernel with err: {err:?}");
}
_ = reply.send(Err(err));
continue;
}
let mut function: CUfunction = ptr::null_mut();
if let Err(err) = unsafe { (cuModuleGetFunction)(&raw mut function, module, name.as_ptr().cast()) }
.check(ErrorStatus::KernelCompilation)
{
if debug_dev {
println!("Failed to launch kernel with err: {err:?}\n");
}
_ = reply.send(Err(err));
continue;
}
let program_id = programs.push(CUDAProgram {
module,
function,
gws,
lws,
});
_ = reply.send(Ok(program_id));
}
CUDACommand::Launch { program_id, args, mut event_wait_list, reply } => {
let stream = next_stream(&mut streams, cuStreamSynchronize);
let program = &programs[program_id];
let mut kernel_params: Vec<*mut c_void> = Vec::new();
for arg in args {
let arg = &buffers[arg];
let ptr: *const u64 = &raw const arg.ptr;
let ptr: *mut u64 = ptr.cast_mut();
kernel_params.push(ptr.cast());
}
while let Some(Event::CUDA(CUDAEvent { event })) = event_wait_list.pop() {
if !event.is_null() {
if let Err(err) =
unsafe { (cuStreamWaitEvent)(stream, event, 0) }.check(ErrorStatus::KernelLaunch)
{
_ = reply.send(Err(err));
continue 'work_thread_loop;
};
}
}
let mut event = ptr::null_mut();
if let Err(err) = unsafe { (cuEventCreate)(&raw mut event, 0) }.check(ErrorStatus::KernelLaunch) {
_ = reply.send(Err(err));
continue;
};
send_or_continue!(
unsafe {
(cuLaunchKernel)(
program.function,
u32::try_from(program.gws.first().copied().unwrap_or(1)).unwrap(),
u32::try_from(program.gws.get(1).copied().unwrap_or(1)).unwrap(),
u32::try_from(program.gws.get(2).copied().unwrap_or(1)).unwrap(),
u32::try_from(program.lws.first().copied().unwrap_or(1)).unwrap(),
u32::try_from(program.lws.get(1).copied().unwrap_or(1)).unwrap(),
u32::try_from(program.lws.get(2).copied().unwrap_or(1)).unwrap(),
0,
stream,
kernel_params.as_mut_ptr(),
ptr::null_mut(),
)
}
.check(ErrorStatus::KernelLaunch),
reply
);
if let Err(err) = unsafe { (cuEventRecord)(event, stream) }.check(ErrorStatus::KernelLaunch) {
_ = reply.send(Err(err));
continue;
}
_ = reply.send(Ok(Event::CUDA(CUDAEvent { event })));
}
CUDACommand::SyncEvents { mut events, reply } => {
while let Some(Event::CUDA(CUDAEvent { event })) = events.pop() {
if !event.is_null() {
if let Err(err) = unsafe { (cuEventSynchronize)(event) }.check(ErrorStatus::KernelSync) {
_ = reply.send(Err(err));
continue;
}
if let Err(err) = unsafe { (cuEventDestroy)(event) }.check(ErrorStatus::KernelSync) {
_ = reply.send(Err(err));
continue;
}
}
}
_ = reply.send(Ok(()));
}
CUDACommand::ReleaseProgram { program_id } => {
let _ = unsafe { (cuModuleUnload)(programs[program_id].module) }.check(ErrorStatus::Deinitialization);
programs.remove(program_id);
}
CUDACommand::ReleaseEvents { events } => {
for event in events {
let Event::CUDA(CUDAEvent { event }) = event else {
unreachable!()
};
_ = unsafe { (cuEventDestroy)(event) }.check(ErrorStatus::Deinitialization);
}
}
}
}
});
let pool = MemoryPool::CUDA(CUDAMemoryPool { tx: tx.clone(), free_bytes });
memory_pools.push(pool);
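// Build the device with placeholder limits first: `CUDADevice::get` needs a
// constructed device to query the real attributes, which replace `dev_info` below.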
let mut dev = CUDADevice {
tx,
device,
dev_info: DeviceInfo {
compute: 1024 * 1024 * 1024 * 1024,
max_global_work_dims: vec![64, 64, 64],
max_local_threads: 1,
max_local_work_dims: vec![1, 1, 1],
local_mem_size: 0,
max_register_bytes: 96,
preferred_vector_size: 16,
tensor_cores: major > 7,
warp_size: 32,
supported_dtypes: u32::MAX,
},
memory_pool_id: PoolId::from(usize::from(memory_pools.len()) - 1),
compute_capability: [major, minor],
};
dev.dev_info = DeviceInfo {
compute: 1024 * 1024 * 1024 * 1024,
max_global_work_dims: vec![
Dim::try_from(dev.get(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, cuDeviceGetAttribute)?).unwrap(),
Dim::try_from(dev.get(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, cuDeviceGetAttribute)?).unwrap(),
Dim::try_from(dev.get(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, cuDeviceGetAttribute)?).unwrap(),
],
max_local_threads: Dim::try_from(dev.get(
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
cuDeviceGetAttribute,
)?)
.unwrap(),
max_local_work_dims: vec![
Dim::try_from(dev.get(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, cuDeviceGetAttribute)?).unwrap(),
Dim::try_from(dev.get(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, cuDeviceGetAttribute)?).unwrap(),
Dim::try_from(dev.get(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, cuDeviceGetAttribute)?).unwrap(),
],
local_mem_size: Dim::try_from(dev.get(
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
cuDeviceGetAttribute,
)?)
.unwrap(),
max_register_bytes: 96,
preferred_vector_size: 16,
tensor_cores: major > 7,
warp_size: 32,
supported_dtypes: u32::MAX,
};
devices.push(Device::CUDA(dev));
}
Ok(())
}
impl CUDAMemoryPool {
#[allow(clippy::needless_pass_by_ref_mut)]
pub const fn deinitialize(&mut self) {
let _ = self;
}
pub const fn free_bytes(&self) -> Dim {
self.free_bytes
}
#[allow(clippy::needless_pass_by_ref_mut)]
pub fn allocate(&mut self, bytes: Dim) -> Result<(PoolBufferId, Event), BackendError> {
if bytes > self.free_bytes {
return Err(BackendError { status: ErrorStatus::MemoryAllocation, context: "Allocation failure.".into() });
}
let (reply, reply_rx) = channel();
self.tx.send(CUDACommand::Allocate { bytes, reply }).unwrap();
reply_rx.recv().unwrap()
}
#[allow(clippy::needless_pass_by_ref_mut)]
pub fn deallocate(&mut self, buffer_id: PoolBufferId, events: Vec<Event>) {
self.tx.send(CUDACommand::Deallocate { buffer_id, events }).unwrap();
}
#[allow(clippy::needless_pass_by_ref_mut)]
pub fn host_to_pool(&mut self, src: &[u8], dst: PoolBufferId, event_wait_list: Vec<Event>) -> Result<Event, BackendError> {
let (reply, reply_rx) = channel();
self.tx
.send(CUDACommand::HostToPool { src: src.as_ptr(), bytes: src.len(), dst, event_wait_list, reply })
.unwrap();
reply_rx.recv().unwrap()
}
#[allow(clippy::needless_pass_by_ref_mut)]
pub fn pool_to_host(&mut self, src: PoolBufferId, dst: &mut [u8], event_wait_list: Vec<Event>) -> Result<(), BackendError> {
let (reply, reply_rx) = channel();
self.tx
.send(CUDACommand::PoolToHost { src, dst: dst.as_mut_ptr(), bytes: dst.len(), event_wait_list, reply })
.unwrap();
reply_rx.recv().unwrap()
}
#[allow(clippy::needless_pass_by_ref_mut)]
pub fn sync_events(&mut self, events: Vec<Event>) -> Result<(), BackendError> {
let (reply, reply_rx) = channel();
self.tx.send(CUDACommand::SyncEvents { events, reply }).unwrap();
reply_rx.recv().unwrap()
}
#[allow(clippy::needless_pass_by_ref_mut)]
pub fn release_events(&mut self, events: Vec<Event>) {
self.tx.send(CUDACommand::ReleaseEvents { events }).unwrap();
}
}
impl CUDADevice {
#[allow(clippy::needless_pass_by_ref_mut)]
pub const fn deinitialize(&mut self) {
let _ = self;
}
pub const fn info(&self) -> &DeviceInfo {
&self.dev_info
}
pub const fn memory_pool_id(&self) -> PoolId {
self.memory_pool_id
}
pub const fn free_compute(&self) -> u128 {
self.dev_info.compute
}
#[allow(clippy::needless_pass_by_ref_mut)]
pub fn compile(&mut self, kernel: &Kernel, debug_asm: bool) -> Result<DeviceProgramId, BackendError> {
let (ptx, name, gws, lws) = Compiler::new().compile(kernel, self.compute_capability, &self.dev_info, debug_asm)?;
let (reply, reply_rx) = channel();
self.tx.send(CUDACommand::Compile { gws, lws, name, ptx, reply }).unwrap();
reply_rx.recv().unwrap()
}
#[allow(clippy::needless_pass_by_ref_mut)]
pub fn launch(
&mut self,
program_id: DeviceProgramId,
_memory_pool: &mut CUDAMemoryPool,
args: &[PoolBufferId],
event_wait_list: Vec<Event>,
) -> Result<Event, BackendError> {
let (reply, reply_rx) = channel();
self.tx
.send(CUDACommand::Launch { program_id, args: args.into(), event_wait_list, reply })
.unwrap();
reply_rx.recv().unwrap()
}
#[allow(clippy::needless_pass_by_ref_mut)]
pub fn release(&mut self, program_id: DeviceProgramId) {
self.tx.send(CUDACommand::ReleaseProgram { program_id }).unwrap();
}
}
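// Picks the least-loaded stream. If even that stream is saturated (load > 20),
// synchronize it, reset its load counter, and pick again.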
fn next_stream(
streams: &mut [CUDAStream],
cuStreamSynchronize: unsafe extern "C" fn(CUstream) -> CUDAStatus,
) -> *mut CUstream_st {
let mut id = streams.iter().enumerate().min_by_key(|(_, s)| s.load).unwrap().0;
if streams[id].load > 20 {
let stream_sync = unsafe { (cuStreamSynchronize)(streams[id].stream) }.check(ErrorStatus::KernelSync);
if stream_sync.is_ok() {
streams[id].load = 0;
}
id = streams.iter().enumerate().min_by_key(|(_, s)| s.load).unwrap().0;
}
streams[id].stream
}
impl CUDADevice {
fn get(
&mut self,
attr: CUdevice_attribute,
cuDeviceGetAttribute: unsafe extern "C" fn(*mut c_int, CUdevice_attribute, CUdevice) -> CUDAStatus,
) -> Result<c_int, BackendError> {
let mut v = 0;
unsafe { cuDeviceGetAttribute(&raw mut v, attr, self.device) }.check(ErrorStatus::DeviceQuery)?;
Ok(v)
}
}
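// Opaque handle types and constants below mirror the CUDA driver API (cuda.h).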
#[repr(C)]
#[derive(Debug, Copy, Clone)]
struct CUctx_st {
_unused: [u8; 0],
}
type CUcontext = *mut CUctx_st;
type CUdevice = c_int;
type CUdeviceptr = u64;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
struct CUmod_st {
_unused: [u8; 0],
}
type CUmodule = *mut CUmod_st;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
struct CUfunc_st {
_unused: [u8; 0],
}
type CUfunction = *mut CUfunc_st;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
struct CUstream_st {
_unused: [u8; 0],
}
type CUstream = *mut CUstream_st;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
struct CUevent_st {
_unused: [u8; 0],
}
type CUevent = *mut CUevent_st;
#[allow(unused)]
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
enum CUjit_option {
CU_JIT_MAX_REGISTERS = 0,
CU_JIT_THREADS_PER_BLOCK = 1,
CU_JIT_WALL_TIME = 2,
CU_JIT_INFO_LOG_BUFFER = 3,
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES = 4,
CU_JIT_ERROR_LOG_BUFFER = 5,
CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES = 6,
CU_JIT_OPTIMIZATION_LEVEL = 7,
CU_JIT_TARGET_FROM_CUCONTEXT = 8,
CU_JIT_TARGET = 9,
CU_JIT_FALLBACK_STRATEGY = 10,
CU_JIT_GENERATE_DEBUG_INFO = 11,
CU_JIT_LOG_VERBOSE = 12,
CU_JIT_GENERATE_LINE_INFO = 13,
CU_JIT_CACHE_MODE = 14,
CU_JIT_NEW_SM3X_OPT = 15,
CU_JIT_FAST_COMPILE = 16,
CU_JIT_GLOBAL_SYMBOL_NAMES = 17,
CU_JIT_GLOBAL_SYMBOL_ADDRESSES = 18,
CU_JIT_GLOBAL_SYMBOL_COUNT = 19,
CU_JIT_NUM_OPTIONS = 20,
}
#[allow(unused)]
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
enum CUdevice_attribute {
CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 1,
CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X = 2,
CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y = 3,
CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z = 4,
CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X = 5,
CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y = 6,
CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z = 7,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK = 8,
CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY = 9,
CU_DEVICE_ATTRIBUTE_WARP_SIZE = 10,
CU_DEVICE_ATTRIBUTE_MAX_PITCH = 11,
CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK = 12,
CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13,
CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT = 14,
CU_DEVICE_ATTRIBUTE_GPU_OVERLAP = 15,
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16,
CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT = 17,
CU_DEVICE_ATTRIBUTE_INTEGRATED = 18,
CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY = 19,
CU_DEVICE_ATTRIBUTE_COMPUTE_MODE = 20,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH = 21,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH = 22,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT = 23,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH = 24,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT = 25,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH = 26,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH = 27,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT = 28,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS = 29,
CU_DEVICE_ATTRIBUTE_SURFACE_ALIGNMENT = 30,
CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS = 31,
CU_DEVICE_ATTRIBUTE_ECC_ENABLED = 32,
CU_DEVICE_ATTRIBUTE_PCI_BUS_ID = 33,
CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID = 34,
CU_DEVICE_ATTRIBUTE_TCC_DRIVER = 35,
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36,
CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH = 37,
CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE = 38,
CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39,
CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT = 40,
CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING = 41,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH = 42,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS = 43,
CU_DEVICE_ATTRIBUTE_CAN_TEX2D_GATHER = 44,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH = 45,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT = 46,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE = 47,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE = 48,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE = 49,
CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID = 50,
CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT = 51,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH = 52,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH = 53,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS = 54,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH = 55,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH = 56,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT = 57,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH = 58,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT = 59,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH = 60,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH = 61,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS = 62,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH = 63,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT = 64,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS = 65,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH = 66,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH = 67,
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS = 68,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH = 70,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT = 71,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH = 72,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH = 73,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT = 74,
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75,
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76,
CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH = 77,
CU_DEVICE_ATTRIBUTE_STREAM_PRIORITIES_SUPPORTED = 78,
CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED = 79,
CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED = 80,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR = 81,
CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR = 82,
CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY = 83,
CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD = 84,
CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID = 85,
CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED = 86,
CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO = 87,
CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS = 88,
CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS = 89,
CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED = 90,
CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM = 91,
CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH = 95,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97,
CU_DEVICE_ATTRIBUTE_CAN_FLUSH_REMOTE_WRITES = 98,
CU_DEVICE_ATTRIBUTE_HOST_REGISTER_SUPPORTED = 99,
CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES = 100,
CU_DEVICE_ATTRIBUTE_DIRECT_MANAGED_MEM_ACCESS_FROM_HOST = 101,
CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED = 102,
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED = 103,
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED = 104,
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105,
CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106,
CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED = 107,
CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE = 108,
CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE = 109,
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED = 110,
CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK = 111,
CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED = 112,
CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113,
CU_DEVICE_ATTRIBUTE_TIMELINE_SEMAPHORE_INTEROP_SUPPORTED = 114,
CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED = 115,
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED = 116,
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS = 117,
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING = 118,
CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES = 119,
CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH = 120,
CU_DEVICE_ATTRIBUTE_DEFERRED_MAPPING_CUDA_ARRAY_SUPPORTED = 121,
CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS = 122,
CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR = 123,
CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED = 124,
CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED = 125,
CU_DEVICE_ATTRIBUTE_MEM_SYNC_DOMAIN_COUNT = 126,
CU_DEVICE_ATTRIBUTE_TENSOR_MAP_ACCESS_SUPPORTED = 127,
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED = 128,
CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS = 129,
CU_DEVICE_ATTRIBUTE_NUMA_CONFIG = 130,
CU_DEVICE_ATTRIBUTE_NUMA_ID = 131,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED = 132,
CU_DEVICE_ATTRIBUTE_MPS_ENABLED = 133,
CU_DEVICE_ATTRIBUTE_HOST_NUMA_ID = 134,
CU_DEVICE_ATTRIBUTE_D3D12_CIG_SUPPORTED = 135,
CU_DEVICE_ATTRIBUTE_MAX,
}
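// Maps each dtype to the PTX type suffix used in instructions such as `ld.global.f32`.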
impl DType {
pub(super) fn ptx(&self) -> &str {
match self {
Self::BF16 => todo!("BF16 is not native to OpenCL, workaround is WIP."),
Self::F16 => "f16",
Self::F32 => "f32",
Self::F64 => "f64",
Self::I8 => "s8",
Self::I16 => "s16",
Self::I32 => "s32",
Self::I64 => "s64",
Self::Bool => "pred",
Self::U8 => "u8",
Self::U16 => "u16",
Self::U32 => "u32",
Self::U64 => "u64",
}
}
}
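// Formats a constant as a PTX immediate; floats are printed with enough digits to
// round-trip and always keep a decimal point.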
impl Constant {
fn ptx(&self) -> String {
fn format_precise(val: impl std::fmt::Display, decimals: usize) -> String {
let s = format!("{:.*}", decimals, val);
let s = s.trim_end_matches('0').trim_end_matches('.');
if s.contains(".") { s.to_string() } else { format!("{s}.0") }
}
match self {
&Self::BF16(x) => format!("{}f", half::bf16::from_le_bytes(x)),
&Self::F16(x) => format!("__float2half({:.6})", half::f16::from_le_bytes(x)),
&Self::F32(x) => format_precise(f32::from_le_bytes(x), 9),
&Self::F64(x) => format_precise(f64::from_le_bytes(x), 18),
Self::U8(x) => format!("{x}"),
Self::I8(x) => format!("{x}"),
Self::I16(x) => format!("{x}"),
Self::U16(x) => format!("{x}"),
Self::U32(x) => format!("{x}U"),
&Self::U64(x) => u64::from_le_bytes(x).to_string(),
Self::I32(x) => format!("{x}"),
&Self::I64(x) => i64::from_le_bytes(x).to_string(),
&Self::Bool(x) => String::from(if x { "1" } else { "0" }),
}
}
}
#[allow(unused)]
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
enum CUDAStatus {
CUDA_SUCCESS = 0,
CUDA_ERROR_INVALID_VALUE = 1,
CUDA_ERROR_OUT_OF_MEMORY = 2,
CUDA_ERROR_NOT_INITIALIZED = 3,
CUDA_ERROR_DEINITIALIZED = 4,
CUDA_ERROR_PROFILER_DISABLED = 5,
CUDA_ERROR_PROFILER_NOT_INITIALIZED = 6,
CUDA_ERROR_PROFILER_ALREADY_STARTED = 7,
CUDA_ERROR_PROFILER_ALREADY_STOPPED = 8,
CUDA_ERROR_NO_DEVICE = 100,
CUDA_ERROR_INVALID_DEVICE = 101,
CUDA_ERROR_INVALID_IMAGE = 200,
CUDA_ERROR_INVALID_CONTEXT = 201,
CUDA_ERROR_CONTEXT_ALREADY_CURRENT = 202,
CUDA_ERROR_MAP_FAILED = 205,
CUDA_ERROR_UNMAP_FAILED = 206,
CUDA_ERROR_ARRAY_IS_MAPPED = 207,
CUDA_ERROR_ALREADY_MAPPED = 208,
CUDA_ERROR_NO_BINARY_FOR_GPU = 209,
CUDA_ERROR_ALREADY_ACQUIRED = 210,
CUDA_ERROR_NOT_MAPPED = 211,
CUDA_ERROR_NOT_MAPPED_AS_ARRAY = 212,
CUDA_ERROR_NOT_MAPPED_AS_POINTER = 213,
CUDA_ERROR_ECC_UNCORRECTABLE = 214,
CUDA_ERROR_UNSUPPORTED_LIMIT = 215,
CUDA_ERROR_CONTEXT_ALREADY_IN_USE = 216,
CUDA_ERROR_PEER_ACCESS_UNSUPPORTED = 217,
CUDA_ERROR_INVALID_PTX = 218,
CUDA_ERROR_INVALID_GRAPHICS_CONTEXT = 219,
CUDA_ERROR_NVLINK_UNCORRECTABLE = 220,
CUDA_ERROR_JIT_COMPILER_NOT_FOUND = 221,
CUDA_ERROR_INVALID_SOURCE = 300,
CUDA_ERROR_FILE_NOT_FOUND = 301,
CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND = 302,
CUDA_ERROR_SHARED_OBJECT_INIT_FAILED = 303,
CUDA_ERROR_OPERATING_SYSTEM = 304,
CUDA_ERROR_INVALID_HANDLE = 400,
CUDA_ERROR_ILLEGAL_STATE = 401,
CUDA_ERROR_NOT_FOUND = 500,
CUDA_ERROR_NOT_READY = 600,
CUDA_ERROR_ILLEGAL_ADDRESS = 700,
CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES = 701,
CUDA_ERROR_LAUNCH_TIMEOUT = 702,
CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING = 703,
CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED = 704,
CUDA_ERROR_PEER_ACCESS_NOT_ENABLED = 705,
CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE = 708,
CUDA_ERROR_CONTEXT_IS_DESTROYED = 709,
CUDA_ERROR_ASSERT = 710,
CUDA_ERROR_TOO_MANY_PEERS = 711,
CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED = 712,
CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED = 713,
CUDA_ERROR_HARDWARE_STACK_ERROR = 714,
CUDA_ERROR_ILLEGAL_INSTRUCTION = 715,
CUDA_ERROR_MISALIGNED_ADDRESS = 716,
CUDA_ERROR_INVALID_ADDRESS_SPACE = 717,
CUDA_ERROR_INVALID_PC = 718,
CUDA_ERROR_LAUNCH_FAILED = 719,
CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE = 720,
CUDA_ERROR_NOT_PERMITTED = 800,
CUDA_ERROR_NOT_SUPPORTED = 801,
CUDA_ERROR_SYSTEM_NOT_READY = 802,
CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803,
CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE = 804,
CUDA_ERROR_STREAM_CAPTURE_UNSUPPORTED = 900,
CUDA_ERROR_STREAM_CAPTURE_INVALIDATED = 901,
CUDA_ERROR_STREAM_CAPTURE_MERGE = 902,
CUDA_ERROR_STREAM_CAPTURE_UNMATCHED = 903,
CUDA_ERROR_STREAM_CAPTURE_UNJOINED = 904,
CUDA_ERROR_STREAM_CAPTURE_ISOLATION = 905,
CUDA_ERROR_STREAM_CAPTURE_IMPLICIT = 906,
CUDA_ERROR_CAPTURED_EVENT = 907,
CUDA_ERROR_STREAM_CAPTURE_WRONG_THREAD = 908,
CUDA_ERROR_TIMEOUT = 909,
CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE = 910,
CUDA_ERROR_UNKNOWN = 999,
}
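// Converts a driver return code into `Ok(())` or a `BackendError` tagged with the
// operation that failed.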
impl CUDAStatus {
fn check(self, status: ErrorStatus) -> Result<(), BackendError> {
if self == Self::CUDA_SUCCESS {
Ok(())
} else {
Err(BackendError { status, context: format!("{self:?}").into() })
}
}
}
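// A single pass over the kernel in execution order that infers the dtype of every
// op and counts how many times each value is consumed; the counts later drive
// register reuse in the compiler.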
fn get_dtypes(kernel: &Kernel) -> (Map<OpId, u32>, Map<OpId, DType>) {
let mut rcs: Map<OpId, u32> = Map::with_capacity_and_hasher(kernel.ops.len().into(), BuildHasherDefault::new());
let mut dtypes: Map<OpId, DType> = Map::with_capacity_and_hasher(100, BuildHasherDefault::new());
let mut loop_ids = Vec::new();
for &op_id in &kernel.order {
match &kernel.ops[op_id] {
Op::ConstView { .. } | Op::StoreView { .. } | Op::LoadView { .. } => unreachable!(),
Op::Const(x) => {
dtypes.insert(op_id, x.dtype());
}
&Op::Define { dtype, .. } => {
dtypes.insert(op_id, dtype);
}
&Op::Load { src, index } => {
dtypes.insert(op_id, dtypes[&src]);
rcs.entry(index).and_modify(|rc| *rc += 1).or_insert(1);
}
&Op::Store { dst, x: src, index } => {
dtypes.insert(op_id, dtypes[&src]);
rcs.entry(dst).and_modify(|rc| *rc += 1).or_insert(1);
rcs.entry(src).and_modify(|rc| *rc += 1).or_insert(1);
rcs.entry(index).and_modify(|rc| *rc += 1).or_insert(1);
}
&Op::Cast { x, dtype } => {
dtypes.insert(op_id, dtype);
rcs.entry(x).and_modify(|rc| *rc += 1).or_insert(1);
}
&Op::Unary { x, .. } => {
dtypes.insert(op_id, dtypes[&x]);
rcs.entry(x).and_modify(|rc| *rc += 1).or_insert(1);
}
&Op::Binary { x, y, bop } => {
if matches!(bop, BOp::Cmpgt | BOp::Cmplt | BOp::NotEq | BOp::And | BOp::Or | BOp::Eq) {
dtypes.insert(op_id, DType::Bool);
} else {
dtypes.insert(op_id, dtypes[&x]);
}
rcs.entry(x).and_modify(|rc| *rc += 1).or_insert(1);
rcs.entry(y).and_modify(|rc| *rc += 1).or_insert(1);
}
Op::Loop { .. } => {
dtypes.insert(op_id, IDX_T);
loop_ids.push(op_id);
}
Op::EndLoop => {
let x = loop_ids.pop().unwrap();
rcs.entry(x).and_modify(|rc| *rc += 1).or_insert(1);
}
&Op::Reduce { x, .. } => {
dtypes.insert(op_id, dtypes[&x]);
rcs.entry(x).and_modify(|rc| *rc += 1).or_insert(1);
}
}
}
(rcs, dtypes)
}
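// PTX code generator. `header` accumulates declarations and `body` the
// instructions; `registers` tracks (dtype, remaining-use-count) pairs so that
// registers whose values are fully consumed can be recycled.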
struct Compiler {
var_map: Map<OpId, u16>,
loops: Vec<(Dim, u16, u16)>,
header: String,
body: String,
indent: String,
registers: Vec<(DType, u32)>,
scopes: Map<OpId, Scope>,
}
impl Compiler {
pub fn new() -> Self {
Self {
var_map: Map::default(),
loops: Vec::new(),
header: String::new(),
body: String::new(),
indent: format!(" "),
registers: Vec::new(),
scopes: Map::default(),
}
}
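// Maps binary ops to PTX opcodes: integer multiply uses `mul.lo` (the low half of
// the product), while float divide uses the approximate `div.approx`.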
fn bop_to_ptx(&self, bop: BOp, dtype: DType) -> &'static str {
match bop {
BOp::Add => "add",
BOp::Sub => "sub",
BOp::Mul => {
if matches!(dtype, DType::F32 | DType::F64) {
"mul"
} else {
"mul.lo"
}
}
BOp::Div => {
if matches!(dtype, DType::F32 | DType::F64) {
"div.approx"
} else {
"div"
}
}
BOp::Pow => todo!(),
BOp::Mod => "rem",
BOp::Cmplt => "setp.lt",
BOp::Cmpgt => "setp.gt",
BOp::Maximum => "max",
BOp::Or => "or",
BOp::And => "and",
BOp::BitXor => "xor",
BOp::BitOr => "or",
BOp::BitAnd => "and",
BOp::BitShiftLeft => "shl.b32",
BOp::BitShiftRight => "shr.b32",
BOp::NotEq => "setp.ne",
BOp::Eq => "setp.eq",
}
}
fn uop_to_ptx(&self, uop: UOp) -> &'static str {
match uop {
UOp::Neg => "neg",
UOp::BitNot => "not",
UOp::Exp2 => "ex2.approx",
UOp::Log2 => "lg2.approx",
UOp::Reciprocal => "rcp.approx",
UOp::Sqrt => "sqrt.approx",
UOp::Sin => "sin.approx",
UOp::Cos => "cos.approx",
UOp::Floor => "floor.approx",
UOp::Trunc => "trunc.approx",
}
}
fn get_scope(&self, var: OpId) -> Scope {
if let Some(&scope) = self.scopes.get(&var) {
scope
} else {
Scope::Register
}
}
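// Returns a register for `dtype` with `rc` pending uses, reusing a dead register
// of the same dtype when one exists.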
fn new_reg(&mut self, dtype: DType, rc: u32) -> u16 {
for (i, reg) in self.registers.iter_mut().enumerate() {
if reg.1 == 0 && reg.0 == dtype {
reg.1 = rc;
return i as u16;
}
}
self.registers.push((dtype, rc));
(self.registers.len() - 1) as u16
}
fn new_var(&mut self, op_id: OpId, dtype: DType, rc: u32) -> u16 {
let i = self.new_reg(dtype, rc);
self.var_map.insert(op_id, i);
i
}
fn get_var(&mut self, x: OpId) -> u16 {
let x = self.var_map[&x];
self.registers[x as usize].1 -= 1;
x
}
fn release_reg(&mut self, x: u16) {
self.registers[x as usize].1 -= 1;
}
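// Emits PTX for `kernel`: collects global/local work sizes from the loop ops,
// generates the entry header and instruction body, then prepends the register
// declarations gathered during codegen.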
pub fn compile(
mut self,
kernel: &Kernel,
cc: [c_int; 2],
dev_info: &DeviceInfo,
debug: bool,
) -> Result<(Vec<u8>, Box<str>, Vec<usize>, Vec<usize>), BackendError> {
use std::fmt::Write;
let mut gws = Vec::new();
let mut lws = Vec::new();
for &op_id in &kernel.order {
if let &Op::Loop { dim, scope } = &kernel.ops[op_id] {
match scope {
Scope::Global => {
gws.push(dim);
}
Scope::Local => {
lws.push(dim);
}
Scope::Register => {}
}
}
}
if lws.iter().product::<usize>() > dev_info.max_local_threads {
return Err(BackendError { status: ErrorStatus::KernelCompilation, context: "Invalid local work size.".into() });
}
let name = format!(
"k_{}__{}",
gws.iter().map(ToString::to_string).collect::<Vec<_>>().join("_"),
lws.iter().map(ToString::to_string).collect::<Vec<_>>().join("_"),
)
.into_boxed_str();
_ = writeln!(
self.header,
".version {0}.{1}\n.target sm_{0}{1}\n.address_size 64\n.visible .entry {name}(",
cc[0], cc[1]
);
for &op_id in &kernel.order {
if let Op::Define { scope, .. } = kernel.ops[op_id] {
if scope == Scope::Global {
writeln!(self.header, "{}.param .u64 g{op_id},", self.indent).unwrap();
}
}
}
self.header.pop();
self.header.pop();
_ = writeln!(self.header, "\n) {{");
let mut loop_id_label_map = Map::default();
let mut label = 0;
let mut n_global_loops = 0;
let (rcs, dtypes) = get_dtypes(kernel);
let mut loop_id = 0;
for &op_id in &kernel.order {
match &kernel.ops[op_id] {
Op::Define { dtype, scope, len, .. } => {
self.scopes.insert(op_id, *scope);
match scope {
Scope::Global => {
_ = writeln!(self.body, "{}ld.param.u64 %p{op_id}, [g{op_id}];", self.indent);
}
Scope::Local => todo!(),
Scope::Register => {
_ = writeln!(
self.body,
"{}.local .align {} .{} %p{op_id}[{len}];",
self.indent,
dtype.bit_size() / 8,
dtype.ptx()
);
}
};
}
Op::Const(constant) => {
let reg = self.new_var(op_id, constant.dtype(), u32::MAX);
_ = writeln!(
self.body,
"{}mov.{} %r{reg}, {};",
self.indent,
constant.dtype().ptx(),
constant.ptx()
);
}
Op::Load { src, index } => {
let dtype = dtypes[&src];
match self.get_scope(*src) {
Scope::Global => {
let byte_shift = (dtype.bit_size() / 8).ilog2();
let idx = self.get_var(*index);
let offset = self.new_reg(DType::U64, 1);
let reg = self.new_var(op_id, dtype, rcs[&op_id]);
if IDX_T == DType::U64 {
if offset != idx {
_ = writeln!(self.body, "{}mov.u64 %r{offset}, %r{idx};", self.indent);
}
} else {
_ = writeln!(self.body, "{}cvt.u64.u32 %r{offset}, %r{idx};", self.indent);
}
_ = writeln!(self.body, "{}shl.b64 %r{offset}, %r{offset}, {byte_shift};", self.indent);
_ = writeln!(self.body, "{}add.u64 %address, %p{src}, %r{offset};", self.indent);
self.release_reg(offset);
_ = writeln!(self.body, "{}ld.global.{} %r{reg}, [%address];", self.indent, dtype.ptx());
}
Scope::Local => todo!(),
Scope::Register => {
let idx = self.get_var(*index);
let reg = self.new_var(op_id, dtype, rcs[&op_id]);
_ = writeln!(
self.body,
"{}ld.local.{} %r{reg}, [%p{src} + %r{idx}];",
self.indent,
dtype.ptx()
);
}
}
}
Op::Store { dst, x, index } => {
let dtype = dtypes[x];
let byte_shift = (dtype.bit_size() / 8).ilog2();
let offset = self.new_reg(DType::U64, 1);
match self.get_scope(*dst) {
Scope::Global => {
if dtype == DType::Bool {
let gstu = self.new_reg(DType::U32, 1);
let idx = self.get_var(*index);
let x = self.get_var(*x);
_ = writeln!(self.body, "{}selp.u32 %r{gstu}, 1, 0, %r{x};", self.indent);
if IDX_T == DType::U64 {
if offset != idx {
_ = writeln!(self.body, "{}mov.u64 %r{offset}, %r{idx};", self.indent);
}
} else {
_ = writeln!(self.body, "{}cvt.u64.u32 %r{offset}, %r{idx};", self.indent);
}
_ = writeln!(self.body, "{}add.u64 %address, %p{dst}, %r{offset};", self.indent);
_ = writeln!(self.body, "{}st.global.u8 [%address], %r{gstu};", self.indent);
self.release_reg(gstu);
} else {
let idx = self.get_var(*index);
let x = self.get_var(*x);
if IDX_T == DType::U64 {
if offset != idx {
_ = writeln!(self.body, "{}mov.u64 %r{offset}, %r{idx};", self.indent);
}
} else {
_ = writeln!(self.body, "{}cvt.u64.u32 %r{offset}, %r{idx};", self.indent);
}
_ = writeln!(self.body, "{}shl.b64 %r{offset}, %r{offset}, {byte_shift};", self.indent);
_ = writeln!(self.body, "{}add.u64 %address, %p{dst}, %r{offset};", self.indent);
_ = writeln!(self.body, "{}st.global.{} [%address], %r{x};", self.indent, dtype.ptx());
}
}
Scope::Local => {
todo!();
}
Scope::Register => {
let idx = self.get_var(*index);
let x = self.get_var(*x);
_ = writeln!(
self.body,
"{}st.local.{} [%p{dst} + %r{idx}], %r{x};",
self.indent,
dtype.ptx()
);
}
}
self.release_reg(offset);
}
&Op::Cast { x, dtype } => {
let xdtype = dtypes[&x];
let x = self.get_var(x);
let reg = self.new_var(op_id, dtype, rcs[&op_id]);
match (dtype, xdtype) {
(DType::Bool, _) => {
if xdtype.is_float() {
_ = writeln!(self.body, "{}setp.ne.{} %r{reg}, %r{x}, 0.0;", self.indent, xdtype.ptx());
} else {
_ = writeln!(self.body, "{}setp.ne.{} %r{reg}, %r{x}, 0;", self.indent, xdtype.ptx());
}
}
(_, DType::Bool) => {
if dtype == DType::F64 {
_ = writeln!(self.body, "{}selp.{} %r{reg}, 1.0, 0.0, %r{x};", self.indent, dtype.ptx(),);
} else if dtype == DType::F32 {
// f32 immediates in PTX use the hex form: 0f3F800000 is 1.0, 0f00000000 is 0.0.
_ = writeln!(self.body, "{}selp.f32 %r{reg}, 0f3F800000, 0f00000000, %r{x};", self.indent);
} else {
_ = writeln!(self.body, "{}selp.{} %r{reg}, 1, 0, %r{x};", self.indent, dtype.ptx());
}
}
(DType::I32, DType::F32) => {
_ = writeln!(
self.body,
"{}cvt.rni.{}.{} %r{reg}, %r{x};",
self.indent,
dtype.ptx(),
xdtype.ptx()
);
}
(_, _) => {
_ = writeln!(
self.body,
"{}cvt.rn.{}.{} %r{reg}, %r{x};",
self.indent,
dtype.ptx(),
xdtype.ptx()
);
}
}
}
Op::Unary { x, uop } => {
let dtype = dtypes[&x];
let x = self.get_var(*x);
let reg = self.new_var(op_id, dtype, rcs[&op_id]);
_ = writeln!(
self.body,
"{}{}.{} %r{reg}, %r{x};",
self.indent,
self.uop_to_ptx(*uop),
dtype.ptx()
);
}
Op::Binary { x, y, bop } => {
let dtype = dtypes[&op_id];
let xr = self.get_var(*x);
let yr = self.get_var(*y);
let reg = self.new_var(op_id, dtype, rcs[&op_id]);
// PTX `shl` only accepts an untyped .bN suffix; `shr` and the other ops take the operand dtype.
let suffix = if matches!(bop, BOp::BitShiftLeft) {
format!("b{}", dtypes[x].bit_size())
} else {
dtypes[x].ptx().to_string()
};
_ = writeln!(
self.body,
"{}{}.{} %r{reg}, %r{xr}, %r{yr};",
self.indent,
self.bop_to_ptx(*bop, dtype),
suffix,
);
}
&Op::Loop { dim, scope } => {
let loop_idx = self.new_var(op_id, IDX_T, rcs[&op_id]);
match scope {
Scope::Global => {
_ = writeln!(
self.body,
"{}{}.u32 %r{loop_idx}, %ctaid.{};",
self.indent,
if IDX_T == DType::U64 { "cvt.u64" } else { "mov" },
match loop_id {
0 => "x",
1 => "y",
2 => "z",
_ => unreachable!(),
}
);
n_global_loops += 1;
}
Scope::Local => {
_ = writeln!(
self.body,
"{}{}.u32 %r{loop_idx}, %tid.{};",
self.indent,
if IDX_T == DType::U64 { "cvt.u64" } else { "mov" },
match loop_id - n_global_loops {
0 => "x",
1 => "y",
2 => "z",
_ => unreachable!(),
}
);
}
Scope::Register => {
let loop_pred = self.new_reg(DType::Bool, 2);
self.loops.push((dim, loop_pred, loop_idx));
_ = writeln!(self.body, "{}mov.{} %r{loop_idx}, 0;", self.indent, IDX_T.ptx());
_ = writeln!(self.body, "{}LOOP_{label}:", self.indent);
loop_id_label_map.insert(loop_id, label);
label += 1;
self.indent += " ";
}
}
loop_id += 1;
}
Op::EndLoop => {
loop_id -= 1;
if let Some((dim, loop_pred, loop_idx)) = self.loops.pop() {
_ = writeln!(self.body, "{}add.{} %r{loop_idx}, %r{loop_idx}, 1;", self.indent, IDX_T.ptx());
writeln!(
self.body,
"{}setp.lt.{} %r{loop_pred}, %r{loop_idx}, {};",
self.indent,
IDX_T.ptx(),
Constant::idx(dim as u64).ptx()
)
.unwrap();
_ = writeln!(
self.body,
"{}@%r{loop_pred} bra LOOP_{};",
self.indent, loop_id_label_map[&loop_id]
);
self.indent.pop();
self.indent.pop();
}
}
Op::ConstView { .. } | Op::LoadView { .. } | Op::StoreView { .. } | Op::Reduce { .. } => {
unreachable!()
}
}
}
_ = writeln!(self.body, "{}ret;\n}}", self.indent);
for &op_id in &kernel.order {
if let Op::Define { scope, .. } = kernel.ops[op_id] {
if scope == Scope::Global {
_ = writeln!(self.header, "{}.reg .s64 %p{op_id};", self.indent);
}
}
}
_ = writeln!(self.header, "{}.reg .u64 %address;", self.indent);
for (reg_id, (dtype, _)) in self.registers.iter().enumerate() {
_ = writeln!(self.header, "{}.reg .{} %r{reg_id};", self.indent, dtype.ptx());
}
self.header.push_str(&self.body);
if debug {
println!("{}", self.header);
}
Ok((self.header.into_bytes(), name, gws, lws))
}
}