use std::sync::Arc;
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig};
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
#[cfg(feature = "curand")]
use cudarc::curand::CudaRng;
#[cfg(feature = "cufft")]
use cudarc::cufft::CudaFft;
/// A recorder that can enqueue one kind of GPU operation.
///
/// Each implementor selects the operation descriptor it accepts through
/// [`RecordMode::Op`] and performs the matching library call in
/// [`RecordMode::enqueue_record`].
pub trait RecordMode {
/// Descriptor of the operation this recorder knows how to enqueue.
type Op;
/// Enqueues `op`, consuming it, and reports any failure as a [`GpuError`].
///
/// The implementations in this file issue the library call and then call
/// `record_write(stream)` on the operation's destination buffer.
/// NOTE(review): what `record_write` tracks is defined by `GpuRef`, not
/// here — confirm there before relying on ordering guarantees.
fn enqueue_record(
&mut self,
stream: &Arc<cudarc::driver::CudaStream>,
op: Self::Op,
) -> Result<(), GpuError>;
}
/// Operands and scalars for one `f32` GEMM enqueued through [`BlasRecorder`].
///
/// `BlasRecorder` passes both inputs with `CUBLAS_OP_N` and uses `m`, `k` and
/// `m` as the leading dimensions of `a`, `b` and `c` respectively.
/// NOTE(review): presumably this computes `c = alpha * a * b + beta * c` on
/// column-major data — confirm against the cuBLAS documentation.
pub struct BlasSgemmOp {
/// First input matrix; handed to `gemm` as the `a` operand.
pub a: GpuRef<f32>,
/// Second input matrix; handed to `gemm` as the `b` operand.
pub b: GpuRef<f32>,
/// Output matrix; written by `gemm` and marked via `record_write`.
pub c: GpuRef<f32>,
/// Forwarded as `GemmConfig::m`; also used as `lda` and `ldc`.
pub m: i32,
/// Forwarded as `GemmConfig::n`.
pub n: i32,
/// Forwarded as `GemmConfig::k`; also used as `ldb`.
pub k: i32,
/// Forwarded as `GemmConfig::alpha`.
pub alpha: f32,
/// Forwarded as `GemmConfig::beta`.
pub beta: f32,
}
/// Source and destination buffers for one copy enqueued through
/// [`MemcpyRecorder`], which performs it with `memcpy_dtod`.
pub struct MemcpyOp {
/// Buffer that is read from.
pub src: GpuRef<f32>,
/// Buffer that is written to.
pub dst: GpuRef<f32>,
}
/// Destination buffer for one `fill_with_uniform` call enqueued through
/// `RngRecorder`. Only compiled with the `curand` feature.
#[cfg(feature = "curand")]
pub struct RngFillUniformOp {
/// Buffer that receives the generated values.
pub dst: GpuRef<f32>,
}
/// Recorder that enqueues [`BlasSgemmOp`]s through a borrowed cuBLAS handle.
pub struct BlasRecorder<'a> {
/// cuBLAS handle on which `gemm` is invoked.
pub handle: &'a CudaBlas,
}
/// Stateless recorder that enqueues [`MemcpyOp`]s on the stream passed to
/// `enqueue_record`.
pub struct MemcpyRecorder;
/// Enqueues buffer-to-buffer copies described by [`MemcpyOp`].
impl RecordMode for MemcpyRecorder {
    type Op = MemcpyOp;

    /// Copies `op.src` into `op.dst` with `memcpy_dtod` on `stream`, then
    /// calls `record_write(stream)` on the destination.
    ///
    /// # Errors
    ///
    /// * Whatever `GpuRef::access` returns for either buffer.
    /// * `GpuError::Unrecoverable` when the destination handle obtained here
    ///   is not the only strong reference to its buffer.
    /// * `GpuError::LibraryError` (`lib: "driver"`) when the copy itself fails.
    fn enqueue_record(
        &mut self,
        stream: &Arc<cudarc::driver::CudaStream>,
        op: Self::Op,
    ) -> Result<(), GpuError> {
        let MemcpyOp { src, dst } = op;
        let source = src.access()?.clone();
        let shared_target = dst.access()?.clone();

        // Exclusive ownership is needed to hand out `&mut` to the copy.
        // NOTE(review): this takes a clone of the handle and then demands it
        // be unique; whether that can succeed depends on what
        // `GpuRef::access` returns — confirm against `GpuRef`.
        let mut target = match Arc::try_unwrap(shared_target) {
            Ok(buffer) => buffer,
            Err(_) => {
                return Err(GpuError::Unrecoverable(
                    "MemcpyRecorder: dst has multiple refs".into(),
                ));
            }
        };

        if let Err(e) = stream.memcpy_dtod(&*source, &mut target) {
            return Err(GpuError::LibraryError {
                lib: "driver",
                msg: format!("record memcpy_dtod: {e}"),
            });
        }

        dst.record_write(stream);

        // Release the buffers only after the write has been recorded.
        drop(source);
        drop(target);
        Ok(())
    }
}
/// Recorder that enqueues `RngFillUniformOp`s through a borrowed cuRAND
/// generator. Only compiled with the `curand` feature.
#[cfg(feature = "curand")]
pub struct RngRecorder<'a> {
/// Generator on which `fill_with_uniform` is invoked.
pub rng: &'a CudaRng,
}
/// Input and output buffers for one `exec_r2c` call enqueued through
/// `FftRecorder`. Only compiled with the `cufft` feature.
#[cfg(feature = "cufft")]
pub struct FftR2COp {
/// `f32` input buffer passed to `exec_r2c`.
pub src: GpuRef<f32>,
/// `float2` output buffer written by `exec_r2c`.
pub dst: GpuRef<cudarc::cufft::sys::float2>,
}
/// Recorder that enqueues `FftR2COp`s through a borrowed cuFFT plan.
/// Only compiled with the `cufft` feature.
#[cfg(feature = "cufft")]
pub struct FftRecorder<'a> {
/// Plan on which `exec_r2c` is invoked.
pub plan: &'a CudaFft,
}
/// Enqueues `exec_r2c` transforms described by [`FftR2COp`].
#[cfg(feature = "cufft")]
impl<'a> RecordMode for FftRecorder<'a> {
    type Op = FftR2COp;

    /// Runs `exec_r2c` on the recorder's plan from `op.src` into `op.dst`,
    /// then calls `record_write(stream)` on the destination.
    ///
    /// NOTE(review): `stream` is only handed to `record_write`; the plan call
    /// itself does not receive it — confirm the plan is bound to this stream.
    ///
    /// # Errors
    ///
    /// * Whatever `GpuRef::access` returns for either buffer.
    /// * `GpuError::Unrecoverable` when the destination handle obtained here
    ///   is not the only strong reference to its buffer.
    /// * `GpuError::LibraryError` (`lib: "cufft"`) when the transform fails.
    fn enqueue_record(
        &mut self,
        stream: &Arc<cudarc::driver::CudaStream>,
        op: Self::Op,
    ) -> Result<(), GpuError> {
        let FftR2COp { src, dst } = op;
        let input = src.access()?.clone();
        let shared_output = dst.access()?.clone();

        // Exclusive ownership is needed to hand out `&mut` to the transform.
        // NOTE(review): whether a freshly cloned handle can ever be unique
        // depends on what `GpuRef::access` returns — confirm against `GpuRef`.
        let mut output = match Arc::try_unwrap(shared_output) {
            Ok(buffer) => buffer,
            Err(_) => {
                return Err(GpuError::Unrecoverable(
                    "FftRecorder: dst has multiple refs".into(),
                ));
            }
        };

        if let Err(e) = self.plan.exec_r2c(&*input, &mut output) {
            return Err(GpuError::LibraryError {
                lib: "cufft",
                msg: format!("record exec_r2c: {e}"),
            });
        }

        dst.record_write(stream);

        // Release the buffers only after the write has been recorded.
        drop(input);
        drop(output);
        Ok(())
    }
}
/// Enqueues uniform fills described by [`RngFillUniformOp`].
#[cfg(feature = "curand")]
impl<'a> RecordMode for RngRecorder<'a> {
    type Op = RngFillUniformOp;

    /// Fills `op.dst` via `fill_with_uniform` on the recorder's generator,
    /// then calls `record_write(stream)` on the destination.
    ///
    /// NOTE(review): `stream` is only handed to `record_write`; the generator
    /// call does not receive it — confirm the generator uses this stream.
    ///
    /// # Errors
    ///
    /// * Whatever `GpuRef::access` returns for the buffer.
    /// * `GpuError::Unrecoverable` when the destination handle obtained here
    ///   is not the only strong reference to its buffer.
    /// * `GpuError::LibraryError` (`lib: "curand"`) when the fill fails.
    fn enqueue_record(
        &mut self,
        stream: &Arc<cudarc::driver::CudaStream>,
        op: Self::Op,
    ) -> Result<(), GpuError> {
        let RngFillUniformOp { dst } = op;
        let shared_target = dst.access()?.clone();

        // Exclusive ownership is needed to hand out `&mut` to the generator.
        // NOTE(review): whether a freshly cloned handle can ever be unique
        // depends on what `GpuRef::access` returns — confirm against `GpuRef`.
        let mut target = match Arc::try_unwrap(shared_target) {
            Ok(buffer) => buffer,
            Err(_) => {
                return Err(GpuError::Unrecoverable(
                    "RngRecorder: dst has multiple refs".into(),
                ));
            }
        };

        if let Err(e) = self.rng.fill_with_uniform(&mut target) {
            return Err(GpuError::LibraryError {
                lib: "curand",
                msg: format!("record fill_uniform: {e:?}"),
            });
        }

        dst.record_write(stream);

        // Release the buffer only after the write has been recorded.
        drop(target);
        Ok(())
    }
}
/// Number of elements in a `rows x cols` matrix, or `None` when either
/// dimension is negative or the product overflows `usize`.
fn sgemm_required_elems(rows: i32, cols: i32) -> Option<usize> {
    let rows = <usize as std::convert::TryFrom<i32>>::try_from(rows).ok()?;
    let cols = <usize as std::convert::TryFrom<i32>>::try_from(cols).ok()?;
    rows.checked_mul(cols)
}

/// Enqueues single-precision GEMMs described by [`BlasSgemmOp`].
impl<'a> RecordMode for BlasRecorder<'a> {
    type Op = BlasSgemmOp;

    /// Runs `gemm` on the recorder's cuBLAS handle with both inputs
    /// untransposed (`CUBLAS_OP_N`) and leading dimensions `lda = m`,
    /// `ldb = k`, `ldc = m`, then calls `record_write(stream)` on `op.c`.
    ///
    /// Before the unsafe library call, `m`, `n` and `k` are checked against
    /// the buffer lengths (`a >= m*k`, `b >= k*n`, `c >= m*n` elements), so a
    /// malformed op is rejected instead of reaching cuBLAS.
    ///
    /// NOTE(review): `stream` is only handed to `record_write`; the `gemm`
    /// call does not receive it — confirm the handle is bound to this stream.
    ///
    /// # Errors
    ///
    /// * Whatever `GpuRef::access` returns for any of the three buffers.
    /// * `GpuError::Unrecoverable` when the handle to `c` obtained here is
    ///   not the only strong reference to its buffer.
    /// * `GpuError::LibraryError` (`lib: "cublas"`) when the dimensions are
    ///   negative, overflow, or exceed a buffer, or when `gemm` itself fails.
    fn enqueue_record(
        &mut self,
        stream: &Arc<cudarc::driver::CudaStream>,
        op: Self::Op,
    ) -> Result<(), GpuError> {
        let BlasSgemmOp {
            a,
            b,
            c,
            m,
            n,
            k,
            alpha,
            beta,
        } = op;
        let a_slice = a.access()?.clone();
        let b_slice = b.access()?.clone();
        let c_slice = c.access()?.clone();

        // Exclusive ownership is needed to hand out `&mut` to `gemm`.
        // NOTE(review): whether a freshly cloned handle can ever be unique
        // depends on what `GpuRef::access` returns — confirm against `GpuRef`.
        let mut c_owned = Arc::try_unwrap(c_slice).map_err(|_| {
            GpuError::Unrecoverable("BlasRecorder: C has multiple live references".into())
        })?;

        // `enqueue_record` is a safe fn, so the preconditions of the unsafe
        // `gemm` call must be established here rather than trusted from the
        // caller-populated public fields.
        let a_len = cudarc::driver::DeviceSlice::<f32>::len(&*a_slice);
        let b_len = cudarc::driver::DeviceSlice::<f32>::len(&*b_slice);
        let c_len = cudarc::driver::DeviceSlice::<f32>::len(&c_owned);
        let fits = |rows: i32, cols: i32, len: usize| {
            sgemm_required_elems(rows, cols).map_or(false, |needed| needed <= len)
        };
        if !(fits(m, k, a_len) && fits(k, n, b_len) && fits(m, n, c_len)) {
            return Err(GpuError::LibraryError {
                lib: "cublas",
                msg: format!(
                    "record gemm: dimensions m={m}, n={n}, k={k} do not fit buffers \
                     (a: {a_len}, b: {b_len}, c: {c_len} elements)"
                ),
            });
        }

        let cfg = GemmConfig::<f32> {
            transa: cublasOperation_t::CUBLAS_OP_N,
            transb: cublasOperation_t::CUBLAS_OP_N,
            m,
            n,
            k,
            alpha,
            lda: m,
            ldb: k,
            beta,
            ldc: m,
        };

        // SAFETY: with untransposed operands and leading dimensions equal to
        // the row counts, cuBLAS touches at most m*k elements of `a`, k*n of
        // `b` and m*n of `c`; the check above guarantees each buffer holds at
        // least that many elements, and `c_owned` is uniquely owned here.
        unsafe {
            self.handle
                .gemm(cfg, &*a_slice, &*b_slice, &mut c_owned)
                .map_err(|e| GpuError::LibraryError {
                    lib: "cublas",
                    msg: format!("record gemm: {e}"),
                })?;
        }

        c.record_write(stream);

        // Release the buffers only after the write has been recorded.
        let _ = (a_slice, b_slice, c_owned);
        Ok(())
    }
}