use crate::backend::{BackendDevice, BackendStorage};
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::CudaFunction;
use float8::F8E4M3;
use half::{bf16, f16};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
/// Process-unique identifier handed to each `CudaDevice` when it is constructed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    /// Returns the next identifier from a process-wide counter that starts at 1.
    fn new() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};

        // The counter only has to hand out distinct values; it is not used to publish
        // any other memory, hence the relaxed ordering.
        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self(id)
    }
}
/// Newtype around cudarc's cuRAND generator so that `Send` can be implemented for it locally.
struct CudaRng(cudarc::curand::CudaRng);
// SAFETY: NOTE(review): this is a manual assertion; nothing in this file shows why the wrapped
// generator may be moved across threads. `CudaDevice` only stores it behind a `Mutex`, which
// presumably serializes every access — confirm against cudarc's `CudaRng` implementation.
unsafe impl Send for CudaRng {}
/// Cache of loaded built-in kernel modules, with one slot per entry of `kernels::ALL_IDS`.
pub struct ModuleStore {
// Slots start out as `None` and are filled by `CudaDevice::get_or_load_func`, which indexes
// the array with `kernels::Module::index()`.
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
}
/// Handle to a CUDA device together with the stream, library handles and kernel caches
/// used to run work on it.
///
/// Every field is either `Copy` or wrapped in an `Arc`, so cloning the handle shares the
/// underlying context, stream, caches and RNG state rather than duplicating them.
#[derive(Clone)]
pub struct CudaDevice {
    /// Identifier compared by `same_device`.
    id: DeviceId,
    /// cudarc context this device was created from.
    context: Arc<cudarc::driver::CudaContext>,
    /// Cache of built-in kernel modules (see `get_or_load_func`).
    modules: Arc<RwLock<ModuleStore>>,
    /// Cache of user-provided modules, keyed by module name (see `get_or_load_custom_func`).
    custom_modules: Arc<RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
    /// Stream on which allocations, copies and kernel launches are issued.
    stream: Arc<cudarc::driver::CudaStream>,
    /// cuBLAS handle created on `stream`.
    pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
    /// cuRAND generator; replaced wholesale by `set_seed`.
    curand: Arc<Mutex<CudaRng>>,
    /// Seed most recently installed, as reported by `get_current_seed`.
    seed_value: Arc<RwLock<u64>>,
}
/// Debug output shows only the device id, e.g. `CudaDevice(DeviceId(1))`.
impl std::fmt::Debug for CudaDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { id, .. } = self;
        write!(f, "CudaDevice({id:?})")
    }
}
/// Thin wrappers that forward memory operations to this device's stream and convert the
/// cudarc error into the crate's `Result` via `WrapErr::w`.
impl CudaDevice {
    /// Allocates room for `len` values of `T` on the device.
    ///
    /// # Safety
    /// NOTE(review): this forwards to cudarc's unsafe `CudaStream::alloc`; the returned
    /// memory is presumably uninitialized and must be written before it is read — confirm
    /// against the cudarc documentation.
    #[allow(clippy::missing_safety_doc)]
    pub unsafe fn alloc<T>(&self, len: usize) -> Result<cudarc::driver::CudaSlice<T>>
    where
        T: cudarc::driver::DeviceRepr,
    {
        self.stream.alloc::<T>(len).w()
    }

    /// Allocates `len` zero-initialized values of `T` on the device.
    pub fn alloc_zeros<T>(&self, len: usize) -> Result<cudarc::driver::CudaSlice<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        self.stream.alloc_zeros::<T>(len).w()
    }

    /// Copies the host data in `src` into the existing device buffer `dst`.
    pub fn memcpy_htod<T, Src, Dst>(&self, src: &Src, dst: &mut Dst) -> Result<()>
    where
        T: cudarc::driver::DeviceRepr,
        Src: cudarc::driver::HostSlice<T> + ?Sized,
        Dst: cudarc::driver::DevicePtrMut<T>,
    {
        self.stream.memcpy_htod(src, dst).w()
    }

    /// Copies the device buffer `src` into a freshly allocated host `Vec`.
    pub fn clone_dtoh<T, Src>(&self, src: &Src) -> Result<Vec<T>>
    where
        T: cudarc::driver::DeviceRepr,
        Src: cudarc::driver::DevicePtr<T>,
    {
        self.stream.clone_dtoh(src).w()
    }

    /// Copies the device buffer `src` into the existing device buffer `dst`.
    pub fn memcpy_dtod<T, Src, Dst>(&self, src: &Src, dst: &mut Dst) -> Result<()>
    where
        Src: cudarc::driver::DevicePtr<T>,
        Dst: cudarc::driver::DevicePtrMut<T>,
    {
        self.stream.memcpy_dtod(src, dst).w()
    }

    /// Copies the device buffer `src` into the existing host buffer `dst`.
    pub fn memcpy_dtoh<T, Src, Dst>(&self, src: &Src, dst: &mut Dst) -> Result<()>
    where
        T: cudarc::driver::DeviceRepr,
        Src: cudarc::driver::DevicePtr<T>,
        Dst: cudarc::driver::HostSlice<T>,
    {
        self.stream.memcpy_dtoh(src, dst).w()
    }

    /// Copies the host data in `src` into a freshly allocated device buffer.
    pub fn clone_htod<T, Src>(&self, src: &Src) -> Result<cudarc::driver::CudaSlice<T>>
    where
        T: cudarc::driver::DeviceRepr,
        Src: cudarc::driver::HostSlice<T> + ?Sized,
    {
        self.stream.clone_htod(src).w()
    }
}
/// A loaded CUDA kernel paired with the stream its launches are built on.
pub struct CudaFunc {
// Kernel handle; exposed through `Deref` and `into_cuda_function`.
func: CudaFunction,
// Stream handed to `launch_builder` when `builder()` is called.
stream: Arc<cudarc::driver::CudaStream>,
}
/// Lets a `CudaFunc` be used wherever a `&CudaFunction` is expected.
impl std::ops::Deref for CudaFunc {
    type Target = CudaFunction;

    fn deref(&self) -> &CudaFunction {
        let Self { func, .. } = self;
        func
    }
}
impl CudaFunc {
    /// Consumes the wrapper and returns the underlying kernel handle, dropping the
    /// stream reference held alongside it.
    pub fn into_cuda_function(self) -> CudaFunction {
        let Self { func, .. } = self;
        func
    }
}
/// Pushes every argument expression onto the builder named by the first identifier.
///
/// For each argument the macro evaluates the expression into a hidden local binding and
/// then calls `$b.arg(&binding)`. Because the bindings are introduced in the caller's scope,
/// the references handed to the builder stay valid for the rest of that scope. The
/// invocation must therefore be used in statement position.
///
/// A trailing comma after the last argument is accepted, so multi-line invocations can be
/// formatted like any other argument list.
#[macro_export]
macro_rules! builder_arg {
    ($b:ident, $($arg:expr),* $(,)?) => {
        $(
            let __arg = $arg;
            $b.arg(&__arg);
        )*
    };
}
impl CudaFunc {
    /// Starts a launch-argument builder for this kernel on the stream stored with it.
    pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
        let Self { func, stream } = self;
        stream.launch_builder(func)
    }
}
impl CudaDevice {
/// Returns a new shared handle to the stream this device issues its work on.
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
    Arc::clone(&self.stream)
}
/// Turns off event tracking on the underlying cudarc context.
///
/// # Safety
/// NOTE(review): the contract is not visible in this file. Per cudarc's
/// `CudaContext::disable_event_tracking`, buffers created after this call presumably no
/// longer record their uses through CUDA events, which would leave cross-stream
/// synchronization up to the caller — confirm against the cudarc version in use.
pub unsafe fn disable_event_tracking(&self) {
self.context.disable_event_tracking()
}
/// Reports whether event tracking is currently enabled on the underlying cudarc context.
pub fn is_event_tracking(&self) -> bool {
self.context.is_event_tracking()
}
/// Generates CUDA source for `kernel`, compiles it to PTX with fast-math enabled, loads
/// the result as a module on this device's context and returns the function `func_name`.
///
/// The module is not stored in either module cache; each call compiles and loads again.
///
/// # Errors
/// Fails if code generation, UTF-8 conversion of the generated source, NVRTC compilation,
/// module loading or the function lookup fails.
#[cfg(all(feature = "ug", not(target_arch = "wasm32")))]
pub fn compile(
    &self,
    func_name: &'static str,
    kernel: candle_ug::lang::ssa::Kernel,
) -> Result<CudaFunc> {
    // Render the kernel to CUDA source text.
    let cuda_code = {
        let mut source = Vec::new();
        candle_ug::cuda::code_gen::gen(&mut source, func_name, &kernel)?;
        String::from_utf8(source)?
    };

    let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(
        cuda_code,
        cudarc::nvrtc::CompileOptions {
            use_fast_math: Some(true),
            ..Default::default()
        },
    )
    .w()?;

    let func = self
        .context
        .load_module(ptx)
        .w()?
        .load_function(func_name)
        .w()?;
    let stream = self.stream.clone();
    Ok(CudaFunc { func, stream })
}
/// Returns the identifier assigned to this device handle when it was constructed.
pub fn id(&self) -> DeviceId {
self.id
}
/// Returns the kernel `fn_name` from the custom module registered under `module_name`,
/// loading that module from `ptx` the first time the name is seen.
///
/// Loaded modules are cached in `custom_modules` by `module_name`; once a name is cached,
/// later calls reuse the cached module and `ptx` is ignored.
///
/// # Errors
/// Fails if the module cannot be loaded or if `fn_name` cannot be found in it.
///
/// # Panics
/// Panics if the `custom_modules` lock has been poisoned.
pub fn get_or_load_custom_func(
    &self,
    fn_name: &str,
    module_name: &str,
    ptx: &str,
) -> Result<CudaFunc> {
    // Fast path under the shared lock. The read guard is a temporary of this statement,
    // so it is released before the write lock is requested below.
    let cached = self.custom_modules.read().unwrap().get(module_name).cloned();
    let module = match cached {
        Some(module) => module,
        None => {
            let mut modules = self.custom_modules.write().unwrap();
            // Look again now that the exclusive lock is held: another thread may have
            // loaded the module between the two lock acquisitions. Without this check
            // the PTX would be loaded a second time and replace the cached entry.
            let existing = modules.get(module_name).cloned();
            match existing {
                Some(module) => module,
                None => {
                    let module = self.context.load_module(ptx.into()).w()?;
                    modules.insert(module_name.to_string(), module.clone());
                    module
                }
            }
        }
    };
    let func = module.load_function(fn_name).w()?;
    Ok(CudaFunc {
        func,
        stream: self.stream.clone(),
    })
}
/// Returns the kernel `fn_name` from the built-in module `mdl`, loading the module's PTX
/// the first time it is requested.
///
/// Loaded modules are cached in the `ModuleStore` slot given by `mdl.index()`.
///
/// # Errors
/// Fails if the module cannot be loaded or if `fn_name` cannot be found in it.
///
/// # Panics
/// Panics if the `modules` lock has been poisoned.
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
    let slot = mdl.index();
    // Fast path under the shared lock. The read guard is a temporary of this statement,
    // so it is released before the write lock is requested below.
    let cached = self.modules.read().unwrap().mdls[slot].clone();
    let module = match cached {
        Some(module) => module,
        None => {
            let mut store = self.modules.write().unwrap();
            // Look again now that the exclusive lock is held: another thread may have
            // filled the slot between the two lock acquisitions. Without this check the
            // PTX would be loaded a second time and replace the cached module.
            let existing = store.mdls[slot].clone();
            match existing {
                Some(module) => module,
                None => {
                    let module = self.context.load_module(mdl.ptx().into()).w()?;
                    store.mdls[slot] = Some(module.clone());
                    module
                }
            }
        }
    };
    let func = module.load_function(fn_name).w()?;
    Ok(CudaFunc {
        func,
        stream: self.stream.clone(),
    })
}
/// Returns a new shared handle to the cuBLAS instance created for this device.
pub fn cublas_handle(&self) -> Arc<cudarc::cublas::CudaBlas> {
    Arc::clone(&self.blas)
}
}
impl CudaDevice {
    /// Creates a device handle for GPU `ordinal` that issues its work on a stream freshly
    /// created with `CudaContext::new_stream`, rather than on the context's default stream
    /// used by `BackendDevice::new`.
    ///
    /// # Errors
    /// Fails if the context, stream, cuBLAS handle or cuRAND generator cannot be created.
    pub fn new_with_stream(ordinal: usize) -> Result<Self> {
        // The same value seeds cuRAND and is recorded as the current seed.
        let seed: u64 = 299792458;

        let context = cudarc::driver::CudaContext::new(ordinal).w()?;
        let stream = context.new_stream().w()?;
        let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
        let rng = cudarc::curand::CudaRng::new(seed, stream.clone()).w()?;
        let empty_store = ModuleStore {
            mdls: [const { None }; kernels::ALL_IDS.len()],
        };

        // The id is allocated last so that failed constructions do not consume one.
        let device = Self {
            id: DeviceId::new(),
            context,
            stream,
            blas: Arc::new(blas),
            curand: Arc::new(Mutex::new(CudaRng(rng))),
            modules: Arc::new(RwLock::new(empty_store)),
            custom_modules: Arc::new(RwLock::new(HashMap::new())),
            seed_value: Arc::new(RwLock::new(seed)),
        };
        Ok(device)
    }
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
/// Creates a device handle for GPU `ordinal` that issues its work on the context's
/// default stream.
///
/// # Errors
/// Fails if the context, cuBLAS handle or cuRAND generator cannot be created.
fn new(ordinal: usize) -> Result<Self> {
    // The same value seeds cuRAND and is recorded as the current seed.
    let seed: u64 = 299792458;

    let context = cudarc::driver::CudaContext::new(ordinal).w()?;
    let stream = context.default_stream();
    let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
    let rng = cudarc::curand::CudaRng::new(seed, stream.clone()).w()?;
    let empty_store = ModuleStore {
        mdls: [const { None }; kernels::ALL_IDS.len()],
    };

    // The id is allocated last so that failed constructions do not consume one.
    let device = Self {
        id: DeviceId::new(),
        context,
        stream,
        blas: Arc::new(blas),
        curand: Arc::new(Mutex::new(CudaRng(rng))),
        modules: Arc::new(RwLock::new(empty_store)),
        custom_modules: Arc::new(RwLock::new(HashMap::new())),
        seed_value: Arc::new(RwLock::new(seed)),
    };
    Ok(device)
}
/// Installs `seed` by replacing the cuRAND generator with a new one created from it, then
/// records the seed for `get_current_seed`.
///
/// The recorded seed is only updated once the new generator has been created successfully.
///
/// # Panics
/// Panics if the generator mutex or the seed lock has been poisoned.
fn set_seed(&self, seed: u64) -> Result<()> {
    // Take the lock first so the generator swap is exclusive with in-flight sampling.
    let mut generator = self.curand.lock().unwrap();
    let reseeded = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
    generator.0 = reseeded;
    *self.seed_value.write().unwrap() = seed;
    Ok(())
}
/// Returns the seed recorded at construction or by the latest successful `set_seed`.
///
/// # Panics
/// Panics if the seed lock has been poisoned.
fn get_current_seed(&self) -> Result<u64> {
    let seed = *self.seed_value.read().unwrap();
    Ok(seed)
}
/// Reports this device as a CUDA location identified by the context's ordinal.
fn location(&self) -> crate::DeviceLocation {
    let gpu_id = self.context.ordinal();
    crate::DeviceLocation::Cuda { gpu_id }
}
/// Two handles denote the same device exactly when their `DeviceId`s are equal. Clones
/// share an id, while separately constructed handles get distinct ids.
fn same_device(&self, rhs: &Self) -> bool {
    let (ours, theirs) = (self.id, rhs.id);
    ours == theirs
}
/// Allocates zero-filled device storage holding `shape.elem_count()` elements of `dtype`.
///
/// # Errors
/// Fails if the allocation fails, or with an internal error for the dtypes that have no
/// CUDA storage variant here (`F6E2M3`, `F6E3M2`, `F4`, `F8E8M0`).
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
    let n = shape.elem_count();
    let slice = match dtype {
        DType::U8 => CudaStorageSlice::U8(self.alloc_zeros::<u8>(n)?),
        DType::U32 => CudaStorageSlice::U32(self.alloc_zeros::<u32>(n)?),
        DType::I16 => CudaStorageSlice::I16(self.alloc_zeros::<i16>(n)?),
        DType::I32 => CudaStorageSlice::I32(self.alloc_zeros::<i32>(n)?),
        DType::I64 => CudaStorageSlice::I64(self.alloc_zeros::<i64>(n)?),
        DType::BF16 => CudaStorageSlice::BF16(self.alloc_zeros::<bf16>(n)?),
        DType::F16 => CudaStorageSlice::F16(self.alloc_zeros::<f16>(n)?),
        DType::F32 => CudaStorageSlice::F32(self.alloc_zeros::<f32>(n)?),
        DType::F64 => CudaStorageSlice::F64(self.alloc_zeros::<f64>(n)?),
        DType::F8E4M3 => CudaStorageSlice::F8E4M3(self.alloc_zeros::<F8E4M3>(n)?),
        DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
            return Err(
                CudaError::InternalError("Dummy types not supported in CUDA backend").into(),
            )
        }
    };
    let device = self.clone();
    Ok(CudaStorage { slice, device })
}
/// Fills new device storage of `shape` with uniform random samples drawn by cuRAND.
///
/// Only `F32` and `F64` are supported. Unless `lo == 0.0` and `up == 1.0`, the samples are
/// passed through `Affine(up - lo, lo)` (presumably computing `x * (up - lo) + lo` —
/// confirm against `Affine`).
///
/// # Errors
/// Returns `UnsupportedDtype` for every other dtype, and propagates allocation, cuRAND and
/// affine-kernel failures.
///
/// # Panics
/// Panics if the generator mutex has been poisoned.
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
    let elem_count = shape.elem_count();
    // The generator stays locked for the whole call, including the optional rescale.
    let rng = self.curand.lock().unwrap();
    let samples = match dtype {
        DType::F32 => {
            // SAFETY: the fresh buffer is handed straight to cuRAND, which is expected to
            // overwrite it entirely before it is exposed — TODO confirm.
            let mut buf = unsafe { self.alloc::<f32>(elem_count)? };
            rng.0.fill_with_uniform(&mut buf).w()?;
            CudaStorageSlice::F32(buf)
        }
        DType::F64 => {
            // SAFETY: same reasoning as the `F32` arm.
            let mut buf = unsafe { self.alloc::<f64>(elem_count)? };
            rng.0.fill_with_uniform(&mut buf).w()?;
            CudaStorageSlice::F64(buf)
        }
        DType::U8
        | DType::U32
        | DType::I16
        | DType::I32
        | DType::I64
        | DType::F16
        | DType::BF16
        | DType::F8E4M3
        | DType::F6E2M3
        | DType::F6E3M2
        | DType::F4
        | DType::F8E8M0 => Err(CudaError::UnsupportedDtype {
            dtype,
            op: "rand_uniform",
        })
        .w()?,
    };
    let slice = if lo != 0. || up != 1.0 {
        use super::utils::Map1;
        let layout = Layout::contiguous(shape);
        let rescale = super::Affine(up - lo, lo);
        rescale.map(&samples, self, &layout)?
    } else {
        samples
    };
    let device = self.clone();
    Ok(CudaStorage { slice, device })
}
/// Fills new device storage with normally distributed samples (`mean`, `std`) drawn by
/// cuRAND.
///
/// Only `F32` and `F64` are supported; for `F32` the parameters are narrowed with `as f32`.
/// When `shape.elem_count()` is odd the buffer is allocated with one extra element, so the
/// returned slice can be one element longer than the shape requires.
///
/// # Errors
/// Returns `UnsupportedDtype` for every other dtype, and propagates allocation and cuRAND
/// failures.
///
/// # Panics
/// Panics if the generator mutex has been poisoned.
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
    let elem_count = shape.elem_count();
    let rng = self.curand.lock().unwrap();
    // Round odd counts up to the next even number.
    // NOTE(review): presumably because cuRAND's normal generation needs an even output
    // count — confirm against the cuRAND documentation.
    let padded_count = elem_count + elem_count % 2;
    let slice = match dtype {
        DType::F32 => {
            // SAFETY: the fresh buffer is handed straight to cuRAND, which is expected to
            // overwrite it entirely before it is exposed — TODO confirm.
            let mut buf = unsafe { self.alloc::<f32>(padded_count)? };
            rng.0
                .fill_with_normal(&mut buf, mean as f32, std as f32)
                .w()?;
            CudaStorageSlice::F32(buf)
        }
        DType::F64 => {
            // SAFETY: same reasoning as the `F32` arm.
            let mut buf = unsafe { self.alloc::<f64>(padded_count)? };
            rng.0.fill_with_normal(&mut buf, mean, std).w()?;
            CudaStorageSlice::F64(buf)
        }
        DType::U8
        | DType::U32
        | DType::I16
        | DType::I32
        | DType::I64
        | DType::F16
        | DType::BF16
        | DType::F8E4M3
        | DType::F6E2M3
        | DType::F6E3M2
        | DType::F4
        | DType::F8E8M0 => Err(CudaError::UnsupportedDtype {
            dtype,
            op: "rand_normal",
        })
        .w()?,
    };
    let device = self.clone();
    Ok(CudaStorage { slice, device })
}
/// Allocates device storage for `shape.elem_count()` elements of `dtype` through
/// `CudaDevice::alloc`, without writing any initial contents.
///
/// # Safety
/// NOTE(review): the contents of the returned storage are presumably unspecified until
/// written; callers should fill it before reading — confirm against cudarc's
/// `CudaStream::alloc`.
///
/// # Errors
/// Fails if the allocation fails, or with an internal error for the dtypes that have no
/// CUDA storage variant here (`F6E2M3`, `F6E3M2`, `F4`, `F8E8M0`).
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
    let n = shape.elem_count();
    let slice = match dtype {
        DType::U8 => CudaStorageSlice::U8(self.alloc::<u8>(n)?),
        DType::U32 => CudaStorageSlice::U32(self.alloc::<u32>(n)?),
        DType::I16 => CudaStorageSlice::I16(self.alloc::<i16>(n)?),
        DType::I32 => CudaStorageSlice::I32(self.alloc::<i32>(n)?),
        DType::I64 => CudaStorageSlice::I64(self.alloc::<i64>(n)?),
        DType::BF16 => CudaStorageSlice::BF16(self.alloc::<bf16>(n)?),
        DType::F16 => CudaStorageSlice::F16(self.alloc::<f16>(n)?),
        DType::F32 => CudaStorageSlice::F32(self.alloc::<f32>(n)?),
        DType::F64 => CudaStorageSlice::F64(self.alloc::<f64>(n)?),
        DType::F8E4M3 => CudaStorageSlice::F8E4M3(self.alloc::<F8E4M3>(n)?),
        DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
            return Err(
                CudaError::InternalError("Dummy types not supported in CUDA backend").into(),
            )
        }
    };
    let device = self.clone();
    Ok(CudaStorage { slice, device })
}
/// Copies the host slice `s` into newly allocated device storage of the matching dtype.
///
/// # Errors
/// Propagates copy failures and returns `UnsupportedDtype` for `F4`, `F6E2M3`, `F6E3M2`
/// and `F8E8M0` data.
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
    let slice = match T::cpu_storage_ref(s) {
        CpuStorageRef::U8(host) => CudaStorageSlice::U8(self.clone_htod(host)?),
        CpuStorageRef::U32(host) => CudaStorageSlice::U32(self.clone_htod(host)?),
        CpuStorageRef::I16(host) => CudaStorageSlice::I16(self.clone_htod(host)?),
        CpuStorageRef::I32(host) => CudaStorageSlice::I32(self.clone_htod(host)?),
        CpuStorageRef::I64(host) => CudaStorageSlice::I64(self.clone_htod(host)?),
        CpuStorageRef::BF16(host) => CudaStorageSlice::BF16(self.clone_htod(host)?),
        CpuStorageRef::F16(host) => CudaStorageSlice::F16(self.clone_htod(host)?),
        CpuStorageRef::F32(host) => CudaStorageSlice::F32(self.clone_htod(host)?),
        CpuStorageRef::F64(host) => CudaStorageSlice::F64(self.clone_htod(host)?),
        CpuStorageRef::F8E4M3(host) => CudaStorageSlice::F8E4M3(self.clone_htod(host)?),
        CpuStorageRef::F4(_)
        | CpuStorageRef::F6E2M3(_)
        | CpuStorageRef::F6E3M2(_)
        | CpuStorageRef::F8E8M0(_) => {
            let unsupported = CudaError::UnsupportedDtype {
                dtype: T::DTYPE,
                op: "storage_from_slice",
            };
            return Err(unsupported.into());
        }
    };
    let device = self.clone();
    Ok(CudaStorage { slice, device })
}
/// Copies the contents of a borrowed `CpuStorage` into newly allocated device storage of
/// the same dtype.
///
/// # Errors
/// Propagates copy failures and returns `UnsupportedDtype` for `F4`, `F6E2M3`, `F6E3M2`
/// and `F8E8M0` storage.
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
    let slice = match storage {
        CpuStorage::U8(host) => CudaStorageSlice::U8(self.clone_htod(host)?),
        CpuStorage::U32(host) => CudaStorageSlice::U32(self.clone_htod(host)?),
        CpuStorage::I16(host) => CudaStorageSlice::I16(self.clone_htod(host)?),
        CpuStorage::I32(host) => CudaStorageSlice::I32(self.clone_htod(host)?),
        CpuStorage::I64(host) => CudaStorageSlice::I64(self.clone_htod(host)?),
        CpuStorage::BF16(host) => CudaStorageSlice::BF16(self.clone_htod(host)?),
        CpuStorage::F16(host) => CudaStorageSlice::F16(self.clone_htod(host)?),
        CpuStorage::F32(host) => CudaStorageSlice::F32(self.clone_htod(host)?),
        CpuStorage::F64(host) => CudaStorageSlice::F64(self.clone_htod(host)?),
        CpuStorage::F8E4M3(host) => CudaStorageSlice::F8E4M3(self.clone_htod(host)?),
        CpuStorage::F4(_)
        | CpuStorage::F6E2M3(_)
        | CpuStorage::F6E3M2(_)
        | CpuStorage::F8E8M0(_) => {
            let unsupported = CudaError::UnsupportedDtype {
                dtype: storage.dtype(),
                op: "storage_from_cpu_storage",
            };
            return Err(unsupported.into());
        }
    };
    let device = self.clone();
    Ok(CudaStorage { slice, device })
}
/// Copies the contents of an owned `CpuStorage` into newly allocated device storage of the
/// same dtype. The host buffer is copied through `clone_htod`, just like the borrowing
/// variant, and dropped afterwards.
///
/// # Errors
/// Propagates copy failures and returns `UnsupportedDtype` for `F4`, `F6E2M3`, `F6E3M2`
/// and `F8E8M0` storage.
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
    let slice = match storage {
        CpuStorage::U8(host) => CudaStorageSlice::U8(self.clone_htod(&host)?),
        CpuStorage::U32(host) => CudaStorageSlice::U32(self.clone_htod(&host)?),
        CpuStorage::I16(host) => CudaStorageSlice::I16(self.clone_htod(&host)?),
        CpuStorage::I32(host) => CudaStorageSlice::I32(self.clone_htod(&host)?),
        CpuStorage::I64(host) => CudaStorageSlice::I64(self.clone_htod(&host)?),
        CpuStorage::BF16(host) => CudaStorageSlice::BF16(self.clone_htod(&host)?),
        CpuStorage::F16(host) => CudaStorageSlice::F16(self.clone_htod(&host)?),
        CpuStorage::F32(host) => CudaStorageSlice::F32(self.clone_htod(&host)?),
        CpuStorage::F64(host) => CudaStorageSlice::F64(self.clone_htod(&host)?),
        CpuStorage::F8E4M3(host) => CudaStorageSlice::F8E4M3(self.clone_htod(&host)?),
        // These patterns bind nothing, so `storage` is still intact for `dtype()`.
        CpuStorage::F4(_)
        | CpuStorage::F6E2M3(_)
        | CpuStorage::F6E3M2(_)
        | CpuStorage::F8E8M0(_) => {
            let unsupported = CudaError::UnsupportedDtype {
                dtype: storage.dtype(),
                op: "storage_from_cpu_storage_owned",
            };
            return Err(unsupported.into());
        }
    };
    let device = self.clone();
    Ok(CudaStorage { slice, device })
}
fn synchronize(&self) -> Result<()> {
self.stream.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}