pub struct CudaBlas { /* private fields */ }
Wrapper around sys::cublasHandle_t.
- Create with CudaBlas::new()
- Execute gemm/gemv kernels with Gemv and Gemm. Both f32 and f64 are supported for both, and f16 is supported for gemm.

Note: this maintains an instance of Arc<CudaDevice>, so it will prevent the device from being dropped.
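A minimal end-to-end sketch of the two bullet points above. This assumes a cudarc-style surrounding API and a CUDA-capable GPU: `CudaDevice::new`, `htod_copy`, the `sys::cublasOperation_t` path, and the exact `GemmConfig` field set are assumptions modeled on cuBLAS's sgemm parameters, not confirmed by this page.

```rust
use std::sync::Arc;

use cudarc::cublas::{sys, CudaBlas, Gemm, GemmConfig};
use cudarc::driver::CudaDevice;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Bind a handle to the first GPU's work stream.
    let dev: Arc<CudaDevice> = CudaDevice::new(0)?;
    let blas = CudaBlas::new(dev.clone())?;

    // Column-major 2x2 matrices copied host-to-device
    // (htod_copy is an assumption about the device API).
    let a = dev.htod_copy(vec![1.0f32, 3.0, 2.0, 4.0])?;
    let b = dev.htod_copy(vec![1.0f32, 0.0, 0.0, 1.0])?;
    let mut c = dev.htod_copy(vec![0.0f32; 4])?;

    // C = alpha * A * B + beta * C; field names follow cuBLAS's
    // sgemm conventions and are assumptions here.
    let cfg = GemmConfig {
        transa: sys::cublasOperation_t::CUBLAS_OP_N,
        transb: sys::cublasOperation_t::CUBLAS_OP_N,
        m: 2, n: 2, k: 2,
        alpha: 1.0f32,
        lda: 2, ldb: 2,
        beta: 0.0f32,
        ldc: 2,
    };
    // Safety: a, b, c are distinct allocations and nothing else is
    // reading or writing them while the kernel runs.
    unsafe { blas.gemm(cfg, &a, &b, &mut c)? };
    Ok(())
}
```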
Implementations
impl CudaBlas
pub fn new(device: Arc<CudaDevice>) -> Result<Self, CublasError>
Creates a new cublas handle and sets the stream to the device's stream.
pub unsafe fn set_stream(
    &self,
    opt_stream: Option<&CudaStream>
) -> Result<(), CublasError>
Sets the handle's current stream to either the stream specified, or the device's default work stream.
Safety
This is unsafe because you can end up scheduling multiple concurrent kernels that all write to the same memory address.
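A sketch of the intended pattern, given an existing Arc<CudaDevice> named `dev` and a CudaBlas named `blas`. `fork_default_stream` is an assumption about the device API, and a CUDA GPU is required; this is illustrative, not the crate's documented usage.

```rust
// Sketch only: fork_default_stream is an assumed device method.
let stream = dev.fork_default_stream()?;
// Safety: nothing else reads or writes the buffers this handle's
// kernels touch while they run on `stream`.
unsafe { blas.set_stream(Some(&stream))? };
// ... launch gemm/gemv work on the new stream ...
unsafe { blas.set_stream(None)? }; // back to the default work stream
```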
Trait Implementations
impl Gemm<f16> for CudaBlas
unsafe fn gemm<A: DevicePtr<f16>, B: DevicePtr<f16>, C: DevicePtrMut<f16>>(
    &self,
    cfg: GemmConfig<f16>,
    a: &A,
    b: &B,
    c: &mut C
) -> Result<(), CublasError>
Matrix-matrix multiplication. See the NVIDIA cuBLAS documentation.
unsafe fn gemm_strided_batched<A: DevicePtr<f16>, B: DevicePtr<f16>, C: DevicePtrMut<f16>>(
    &self,
    cfg: StridedBatchedConfig<f16>,
    a: &A,
    b: &B,
    c: &mut C
) -> Result<(), CublasError>
Batched matrix-matrix multiplication with stride support on the batch dimension. See the NVIDIA cuBLAS documentation.
impl Gemm<f32> for CudaBlas
unsafe fn gemm<A: DevicePtr<f32>, B: DevicePtr<f32>, C: DevicePtrMut<f32>>(
    &self,
    cfg: GemmConfig<f32>,
    a: &A,
    b: &B,
    c: &mut C
) -> Result<(), CublasError>
Matrix-matrix multiplication. See the NVIDIA cuBLAS documentation.
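The semantics of the call can be pinned down with a CPU reference: cuBLAS gemm computes C = alpha * op(A) * op(B) + beta * C over column-major matrices. The sketch below (no transposes, leading dimensions as in cuBLAS) is illustrative and is not the crate's code.

```rust
/// CPU reference for the column-major GEMM that cuBLAS performs:
/// C = alpha * A * B + beta * C (no transposes, for simplicity).
/// `lda`/`ldb`/`ldc` are leading dimensions, as in cuBLAS.
fn gemm_ref(
    m: usize, n: usize, k: usize,
    alpha: f32, beta: f32,
    a: &[f32], lda: usize,     // A is m x k, column-major
    b: &[f32], ldb: usize,     // B is k x n, column-major
    c: &mut [f32], ldc: usize, // C is m x n, column-major
) {
    for col in 0..n {
        for row in 0..m {
            let mut acc = 0.0f32;
            for i in 0..k {
                acc += a[row + i * lda] * b[i + col * ldb];
            }
            c[row + col * ldc] = alpha * acc + beta * c[row + col * ldc];
        }
    }
}

fn main() {
    // A = [[1, 2], [3, 4]] stored column-major; B = identity.
    let a = [1.0f32, 3.0, 2.0, 4.0];
    let b = [1.0f32, 0.0, 0.0, 1.0];
    let mut c = [0.0f32; 4];
    gemm_ref(2, 2, 2, 1.0, 0.0, &a, 2, &b, 2, &mut c, 2);
    assert_eq!(c, a); // A * I = A
}
```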
unsafe fn gemm_strided_batched<A: DevicePtr<f32>, B: DevicePtr<f32>, C: DevicePtrMut<f32>>(
    &self,
    cfg: StridedBatchedConfig<f32>,
    a: &A,
    b: &B,
    c: &mut C
) -> Result<(), CublasError>
Batched matrix-matrix multiplication with stride support on the batch dimension. See the NVIDIA cuBLAS documentation.
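As with plain gemm, the strided-batched semantics can be shown with a CPU reference: `batch` independent products whose matrices sit a fixed number of elements apart in each buffer. The stride parameter names follow cuBLAS conventions; the sketch (column-major, no transposes, leading dimensions fixed to the row counts) is illustrative only.

```rust
/// CPU reference for strided-batched GEMM: `batch` independent
/// products C_p = alpha * A_p * B_p + beta * C_p, where matrix p
/// starts `p * stride_*` elements into its buffer (column-major,
/// no transposes, leading dimensions m/k/m for simplicity).
fn gemm_strided_batched_ref(
    batch: usize,
    m: usize, n: usize, k: usize,
    alpha: f32, beta: f32,
    a: &[f32], stride_a: usize,
    b: &[f32], stride_b: usize,
    c: &mut [f32], stride_c: usize,
) {
    for p in 0..batch {
        let (ao, bo, co) = (p * stride_a, p * stride_b, p * stride_c);
        for col in 0..n {
            for row in 0..m {
                let mut acc = 0.0f32;
                for i in 0..k {
                    acc += a[ao + row + i * m] * b[bo + i + col * k];
                }
                let idx = co + row + col * m;
                c[idx] = alpha * acc + beta * c[idx];
            }
        }
    }
}

fn main() {
    // Two 1x1 "matrices" per operand: c[p] = a[p] * b[p].
    let a = [2.0f32, 3.0];
    let b = [5.0f32, 7.0];
    let mut c = [0.0f32; 2];
    gemm_strided_batched_ref(2, 1, 1, 1, 1.0, 0.0, &a, 1, &b, 1, &mut c, 1);
    assert_eq!(c, [10.0, 21.0]);
}
```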
impl Gemm<f64> for CudaBlas
unsafe fn gemm<A: DevicePtr<f64>, B: DevicePtr<f64>, C: DevicePtrMut<f64>>(
    &self,
    cfg: GemmConfig<f64>,
    a: &A,
    b: &B,
    c: &mut C
) -> Result<(), CublasError>
Matrix-matrix multiplication. See the NVIDIA cuBLAS documentation.
unsafe fn gemm_strided_batched<A: DevicePtr<f64>, B: DevicePtr<f64>, C: DevicePtrMut<f64>>(
    &self,
    cfg: StridedBatchedConfig<f64>,
    a: &A,
    b: &B,
    c: &mut C
) -> Result<(), CublasError>
Batched matrix-matrix multiplication with stride support on the batch dimension. See the NVIDIA cuBLAS documentation.
Auto Trait Implementations
impl Send for CudaBlas
impl Sync for CudaBlas
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.