use std::mem::MaybeUninit;
use std::ptr::null_mut;
use bytemuck::cast_slice;
use crate::bindings::{
cublasCreate_v2, cublasDestroy_v2, cublasHandle_t, cublasLtCreate, cublasLtDestroy, cublasLtHandle_t,
cublasSetStream_v2, cudaDeviceAttr, cudaDeviceGetAttribute, cudaDeviceProp, cudaEventRecord, cudaGetDevice,
cudaGetDeviceCount, cudaSetDevice, cudaStreamBeginCapture, cudaStreamCaptureMode,
cudaStreamCreate, cudaStreamDestroy, cudaStreamEndCapture, cudaStreamSynchronize, cudaStreamWaitEvent,
cudaStream_t, cudnnCreate, cudnnDestroy, cudnnHandle_t, cudnnSetStream,
};
use crate::wrapper::event::CudaEvent;
use crate::wrapper::graph::CudaGraph;
use crate::wrapper::mem::device::DevicePtr;
use crate::wrapper::status::Status;
use crate::bindings::cudaGetDeviceProperties_v2 as cudaGetDeviceProperties;
/// Lightweight, copyable identifier for a CUDA device.
///
/// Wraps the integer device ordinal that is handed to the CUDA runtime
/// (for example by `cudaSetDevice`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Device(i32);
/// Compute capability of a device, split into its major and minor version numbers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ComputeCapability {
    /// Major version, as reported by the `cudaDevAttrComputeCapabilityMajor` attribute.
    pub major: i32,
    /// Minor version, as reported by the `cudaDevAttrComputeCapabilityMinor` attribute.
    pub minor: i32,
}
impl Device {
pub fn new(device: i32) -> Self {
assert!(
0 <= device && device < cuda_device_count(),
"Device with id {} doesn't exist",
device
);
Device(device)
}
pub fn all() -> impl Iterator<Item = Self> {
(0..cuda_device_count()).map(Device::new)
}
pub fn current() -> Device {
unsafe {
let mut inner = 0;
cudaGetDevice(&mut inner as *mut _).unwrap();
Device::new(inner)
}
}
pub fn inner(self) -> i32 {
self.0
}
pub fn switch_to(self) {
unsafe { cudaSetDevice(self.inner()).unwrap() }
}
pub fn alloc(self, len_bytes: usize) -> DevicePtr {
DevicePtr::alloc(self, len_bytes)
}
pub fn properties(self) -> cudaDeviceProp {
unsafe {
self.switch_to();
let mut properties = MaybeUninit::uninit();
cudaGetDeviceProperties(properties.as_mut_ptr(), self.inner()).unwrap();
properties.assume_init()
}
}
pub fn attribute(self, attribute: cudaDeviceAttr) -> i32 {
unsafe {
let mut value: i32 = 0;
cudaDeviceGetAttribute(&mut value as *mut _, attribute, self.inner()).unwrap();
value
}
}
pub fn compute_capability(self) -> ComputeCapability {
ComputeCapability {
major: self.attribute(cudaDeviceAttr::cudaDevAttrComputeCapabilityMajor),
minor: self.attribute(cudaDeviceAttr::cudaDevAttrComputeCapabilityMinor),
}
}
pub fn name(self) -> String {
let properties = self.properties();
let name = &properties.name;
let len = name.iter().position(|&c| c == 0).unwrap_or(name.len());
std::str::from_utf8(cast_slice::<i8, u8>(&name[..len]))
.unwrap()
.to_owned()
}
}
/// Asks the CUDA runtime how many devices are available, via `cudaGetDeviceCount`.
fn cuda_device_count() -> i32 {
    let mut device_count: i32 = 0;
    unsafe { cudaGetDeviceCount(&mut device_count).unwrap() };
    device_count
}
/// An owned CUDA stream together with the device it was created on.
///
/// The raw stream is passed to `cudaStreamDestroy` when this value is dropped.
#[derive(Debug)]
pub struct CudaStream {
/// Device that was made current when the stream was created.
device: Device,
/// Raw stream handle produced by `cudaStreamCreate`.
inner: cudaStream_t,
}
impl Drop for CudaStream {
    /// Hands the raw stream to `cudaStreamDestroy`, reporting failures through `unwrap_in_drop`.
    fn drop(&mut self) {
        let raw = self.inner;
        unsafe { cudaStreamDestroy(raw).unwrap_in_drop() }
    }
}
impl CudaStream {
    /// Makes `device` current and creates a new stream with `cudaStreamCreate`.
    pub fn new(device: Device) -> Self {
        device.switch_to();
        let mut raw: cudaStream_t = null_mut();
        unsafe { cudaStreamCreate(&mut raw).unwrap() };
        CudaStream { device, inner: raw }
    }

    /// Calls `cudaStreamSynchronize` on this stream and unwraps the returned status.
    pub fn synchronize(&self) {
        unsafe { cudaStreamSynchronize(self.inner).unwrap() }
    }

    /// Returns the device this stream was created on.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Returns the raw stream handle.
    ///
    /// # Safety
    /// The handle is still owned by `self` and is destroyed when `self` is dropped.
    /// NOTE(review): presumably the caller must not destroy or outlive it; confirm the contract.
    pub unsafe fn inner(&self) -> cudaStream_t {
        self.inner
    }

    /// Creates a fresh event, records it on this stream and returns it.
    pub fn record_event(&self) -> CudaEvent {
        let recorded = CudaEvent::new();
        self.record_existing_event(&recorded);
        recorded
    }

    /// Records `event` on this stream with `cudaEventRecord`.
    pub fn record_existing_event(&self, event: &CudaEvent) {
        unsafe {
            let raw_event = event.inner();
            cudaEventRecord(raw_event, self.inner).unwrap()
        }
    }

    /// Calls `cudaStreamWaitEvent` for `event` on this stream, with flags set to `0`.
    pub fn wait_for_event(&self, event: &CudaEvent) {
        unsafe { cudaStreamWaitEvent(self.inner, event.inner(), 0).unwrap() }
    }

    /// Starts capturing on this stream in global capture mode.
    ///
    /// # Safety
    /// NOTE(review): the exact caller obligations are not visible here.
    /// Presumably every `begin_capture` must be paired with an `end_capture`; confirm.
    pub unsafe fn begin_capture(&self) {
        let mode = cudaStreamCaptureMode::cudaStreamCaptureModeGlobal;
        cudaStreamBeginCapture(self.inner, mode).unwrap()
    }

    /// Ends capturing on this stream and wraps the resulting raw graph in a `CudaGraph`.
    ///
    /// # Safety
    /// NOTE(review): presumably must follow a matching `begin_capture` on this stream; confirm.
    pub unsafe fn end_capture(&self) -> CudaGraph {
        let mut captured = null_mut();
        cudaStreamEndCapture(self.inner, &mut captured).unwrap();
        CudaGraph::new_from_inner(captured)
    }
}
/// An owned cuDNN handle together with the stream it was bound to.
///
/// The handle is passed to `cudnnDestroy` when this value is dropped. The `Drop`
/// implementation runs before the fields are dropped, so the handle is destroyed
/// before the stream is.
#[derive(Debug)]
pub struct CudnnHandle {
/// Raw handle produced by `cudnnCreate`.
inner: cudnnHandle_t,
/// Stream that was assigned to the handle with `cudnnSetStream`.
stream: CudaStream,
}
impl Drop for CudnnHandle {
    /// Makes the handle's device current, then destroys the handle with `cudnnDestroy`.
    fn drop(&mut self) {
        self.stream.device().switch_to();
        unsafe { cudnnDestroy(self.inner).unwrap_in_drop() }
    }
}
impl CudnnHandle {
    /// Creates a handle on `device`, backed by a newly created stream.
    pub fn new(device: Device) -> Self {
        let stream = CudaStream::new(device);
        Self::new_with_stream(stream)
    }

    /// Creates a handle on the device of `stream` and binds the handle to that stream.
    ///
    /// The stream's device is made current before the handle is created.
    pub fn new_with_stream(stream: CudaStream) -> Self {
        stream.device().switch_to();
        let mut handle: cudnnHandle_t = null_mut();
        unsafe {
            cudnnCreate(&mut handle).unwrap();
            cudnnSetStream(handle, stream.inner).unwrap();
        }
        CudnnHandle { inner: handle, stream }
    }

    /// Returns the device of the stream this handle is bound to.
    pub fn device(&self) -> Device {
        self.stream.device
    }

    /// Borrows the stream this handle is bound to.
    pub fn stream(&self) -> &CudaStream {
        &self.stream
    }

    /// Returns the raw cuDNN handle.
    ///
    /// # Safety
    /// The handle is still owned by `self` and is destroyed when `self` is dropped.
    /// NOTE(review): presumably the caller must not destroy or outlive it; confirm the contract.
    pub unsafe fn inner(&self) -> cudnnHandle_t {
        self.inner
    }
}
/// An owned cuBLAS handle together with the stream it was bound to.
///
/// The handle is passed to `cublasDestroy_v2` when this value is dropped. The `Drop`
/// implementation runs before the fields are dropped, so the handle is destroyed
/// before the stream is.
#[derive(Debug)]
pub struct CublasHandle {
/// Raw handle produced by `cublasCreate_v2`.
inner: cublasHandle_t,
/// Stream that was assigned to the handle with `cublasSetStream_v2`.
stream: CudaStream,
}
impl Drop for CublasHandle {
    /// Makes the handle's device current, then destroys the handle with `cublasDestroy_v2`.
    fn drop(&mut self) {
        // cuBLAS assumes the device tied to a context is the same at `cublasCreate` and
        // `cublasDestroy`. Another device may have been made current since creation, so
        // switch back first, as the cuDNN handle's drop does.
        self.stream.device().switch_to();
        unsafe { cublasDestroy_v2(self.inner).unwrap_in_drop() }
    }
}
impl CublasHandle {
    /// Creates a handle on `device`, backed by a newly created stream.
    pub fn new(device: Device) -> Self {
        let stream = CudaStream::new(device);
        Self::new_with_stream(stream)
    }

    /// Creates a handle on the device of `stream` and binds the handle to that stream.
    ///
    /// The stream's device is made current before the handle is created.
    pub fn new_with_stream(stream: CudaStream) -> Self {
        stream.device().switch_to();
        let mut handle: cublasHandle_t = null_mut();
        unsafe {
            cublasCreate_v2(&mut handle).unwrap();
            cublasSetStream_v2(handle, stream.inner).unwrap();
        }
        CublasHandle { inner: handle, stream }
    }

    /// Borrows the stream this handle is bound to.
    pub fn stream(&self) -> &CudaStream {
        &self.stream
    }

    /// Returns the raw cuBLAS handle.
    ///
    /// # Safety
    /// The handle is still owned by `self` and is destroyed when `self` is dropped.
    /// NOTE(review): presumably the caller must not destroy or outlive it; confirm the contract.
    pub unsafe fn inner(&self) -> cublasHandle_t {
        self.inner
    }
}
/// An owned cuBLASLt handle.
///
/// The handle is passed to `cublasLtDestroy` when this value is dropped.
#[derive(Debug)]
pub struct CublasLtHandle {
/// Raw handle produced by `cublasLtCreate`.
inner: cublasLtHandle_t,
}
impl Drop for CublasLtHandle {
    /// Hands the raw handle to `cublasLtDestroy`, reporting failures through `unwrap_in_drop`.
    fn drop(&mut self) {
        let raw = self.inner;
        unsafe { cublasLtDestroy(raw).unwrap_in_drop() }
    }
}
impl CublasLtHandle {
    /// Makes `device` current and creates a cuBLASLt handle with `cublasLtCreate`.
    pub fn new(device: Device) -> Self {
        device.switch_to();
        let mut handle: cublasLtHandle_t = null_mut();
        unsafe { cublasLtCreate(&mut handle).unwrap() };
        CublasLtHandle { inner: handle }
    }

    /// Returns the raw cuBLASLt handle.
    ///
    /// # Safety
    /// The handle is still owned by `self` and is destroyed when `self` is dropped.
    /// NOTE(review): presumably the caller must not destroy or outlive it; confirm the contract.
    pub unsafe fn inner(&self) -> cublasLtHandle_t {
        self.inner
    }
}