#[cfg(feature = "cuda")]
use std::sync::Arc;
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaContext, CudaStream};
use crate::error::GpuResult;
#[cfg(not(feature = "cuda"))]
use crate::error::GpuError;
/// Handle to one CUDA device: its driver context, that context's default
/// stream, and the ordinal it was created with.
///
/// Cloning is cheap. The context and stream are held in `Arc`s, so a clone
/// shares the same underlying handles rather than creating new ones.
#[cfg(feature = "cuda")]
#[derive(Clone)]
pub struct GpuDevice {
    ctx: Arc<CudaContext>,
    stream: Arc<CudaStream>,
    ordinal: usize,
}

#[cfg(feature = "cuda")]
impl GpuDevice {
    /// Creates a CUDA context for device `ordinal` and keeps that context's
    /// default stream alongside it.
    ///
    /// # Errors
    ///
    /// Returns an error if `CudaContext::new(ordinal)` fails. The driver error
    /// is converted into the crate's error type via `?`.
    pub fn new(ordinal: usize) -> GpuResult<Self> {
        let ctx = CudaContext::new(ordinal)?;
        let stream = ctx.default_stream();
        Ok(Self {
            ctx,
            stream,
            ordinal,
        })
    }

    /// Borrows the shared CUDA context.
    #[inline]
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Borrows the stream obtained from the context's `default_stream()` in [`GpuDevice::new`].
    #[inline]
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// Returns the device ordinal passed to [`GpuDevice::new`].
    #[inline]
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }
}
#[cfg(feature = "cuda")]
impl std::fmt::Debug for GpuDevice {
    /// Shows only the device ordinal. The CUDA context and stream handles are
    /// left out, which `finish_non_exhaustive` signals with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("GpuDevice");
        out.field("ordinal", &self.ordinal);
        out.finish_non_exhaustive()
    }
}
/// Stand-in device handle for builds without the `cuda` feature.
///
/// It records only the device ordinal. In this configuration
/// `GpuDevice::new` always returns an error.
#[cfg(not(feature = "cuda"))]
#[derive(Debug, Clone)]
pub struct GpuDevice {
    ordinal: usize,
}
#[cfg(not(feature = "cuda"))]
impl GpuDevice {
    /// Fallback constructor used when the crate is built without the `cuda` feature.
    ///
    /// # Errors
    ///
    /// Always returns `GpuError::NoCudaFeature`. The `ordinal` argument is ignored.
    pub fn new(ordinal: usize) -> GpuResult<Self> {
        // Accepted only so the signature matches the CUDA-enabled constructor.
        let _ignored = ordinal;
        Err(GpuError::NoCudaFeature)
    }

    /// Returns the stored device ordinal.
    #[inline]
    pub fn ordinal(&self) -> usize {
        let Self { ordinal } = self;
        *ordinal
    }
}