#[cfg(feature = "cuda")]
use std::sync::Arc;
#[cfg(feature = "cuda")]
use cudarc::cublas::CudaBlas;
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaContext, CudaStream};
use crate::error::GpuResult;
#[cfg(not(feature = "cuda"))]
use crate::error::GpuError;
#[cfg(feature = "cuda")]
/// Handle to one CUDA device: a shared CUDA context, the stream work is issued on,
/// and a cuBLAS handle bound to that stream.
pub struct GpuDevice {
    /// CUDA context for device `ordinal`; shared (`Arc`) with forks and clones.
    ctx: Arc<CudaContext>,
    /// Stream the `blas` handle was created on.
    stream: Arc<CudaStream>,
    /// cuBLAS handle created from `stream`.
    /// NOTE(review): presumably `CudaBlas` keeps its own `Arc<CudaStream>`, which keeps the
    /// stream (and context) alive until the handle is destroyed. Verify against cudarc.
    blas: CudaBlas,
    /// Device ordinal passed to `CudaContext::new`.
    ordinal: usize,
}
#[cfg(feature = "cuda")]
impl GpuDevice {
    /// Opens device `ordinal`. Creates its CUDA context, takes that context's default
    /// stream, and binds a fresh cuBLAS handle to the stream.
    ///
    /// # Errors
    ///
    /// Returns an error if `CudaContext::new` or `CudaBlas::new` fails. The underlying
    /// driver/cuBLAS error is converted into the crate's error type by `?`.
    pub fn new(ordinal: usize) -> GpuResult<Self> {
        let ctx = CudaContext::new(ordinal)?;
        let stream = ctx.default_stream();
        let blas = CudaBlas::new(Arc::clone(&stream))?;
        Ok(GpuDevice {
            ordinal,
            blas,
            stream,
            ctx,
        })
    }

    /// Builds a sibling device that shares `parent`'s context and ordinal. It runs on a
    /// stream obtained from `parent`'s stream via `CudaStream::fork` and owns a new cuBLAS
    /// handle bound to that forked stream.
    ///
    /// NOTE(review): the name suggests the forked stream is meant for CUDA graph capture
    /// on a side stream. Confirm with callers.
    ///
    /// # Errors
    ///
    /// Returns an error if forking the stream or creating the cuBLAS handle fails.
    pub fn fork_for_capture(parent: &GpuDevice) -> GpuResult<Self> {
        let forked = parent.stream.fork()?;
        let blas = CudaBlas::new(Arc::clone(&forked))?;
        let ctx = Arc::clone(&parent.ctx);
        Ok(GpuDevice {
            ordinal: parent.ordinal,
            ctx,
            stream: forked,
            blas,
        })
    }

    /// Shared CUDA context of this device.
    #[inline]
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Stream that this device's cuBLAS handle is bound to.
    #[inline]
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// cuBLAS handle bound to [`Self::stream`].
    #[inline]
    pub fn blas(&self) -> &CudaBlas {
        &self.blas
    }

    /// Ordinal of the device this handle was opened on.
    #[inline]
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }
}
#[cfg(feature = "cuda")]
impl Clone for GpuDevice {
    /// Produces a handle that shares this device's context and stream (reference-count
    /// bumps only) but owns a newly created cuBLAS handle on that same stream.
    ///
    /// # Panics
    ///
    /// Panics if `CudaBlas::new` fails, because `Clone` has no way to report an error.
    fn clone(&self) -> Self {
        let stream = Arc::clone(&self.stream);
        let blas = CudaBlas::new(Arc::clone(&stream))
            .expect("CudaBlas::new failed in GpuDevice::clone");
        GpuDevice {
            ctx: Arc::clone(&self.ctx),
            stream,
            blas,
            ordinal: self.ordinal,
        }
    }
}
#[cfg(feature = "cuda")]
impl std::fmt::Debug for GpuDevice {
    // Only the ordinal is printed; the context, stream and cuBLAS handle are elided and
    // rendered as `..` by `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("GpuDevice");
        out.field("ordinal", &self.ordinal);
        out.finish_non_exhaustive()
    }
}
#[cfg(not(feature = "cuda"))]
#[derive(Clone, Debug)]
/// Stand-in device type compiled when the `cuda` feature is disabled. Only the
/// ordinal is stored; in this configuration `GpuDevice::new` returns
/// `GpuError::NoCudaFeature`.
pub struct GpuDevice {
    /// Device ordinal this handle refers to.
    ordinal: usize,
}
#[cfg(not(feature = "cuda"))]
impl GpuDevice {
    /// Would open device `ordinal`, but this build has no CUDA support.
    ///
    /// # Errors
    ///
    /// Always returns [`GpuError::NoCudaFeature`], whatever `ordinal` is.
    pub fn new(ordinal: usize) -> GpuResult<Self> {
        // The ordinal is intentionally ignored. Binding it to `_` silences the
        // unused-variable warning without renaming the documented parameter.
        let _ = ordinal;
        Err(GpuError::NoCudaFeature)
    }

    /// Ordinal of the device this handle refers to.
    #[inline]
    pub fn ordinal(&self) -> usize { self.ordinal }
}