#[cfg(feature = "cuda")]
use std::sync::Arc;
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, CudaStream, DeviceRepr, ValidAsZeroBits};
use crate::error::{GpuError, GpuResult};
/// A single value of type `T` stored in a one-element device buffer.
///
/// The host writes the value with `DeviceScalar::update`, and the buffer is
/// exposed through `DeviceScalar::inner`.
#[cfg(feature = "cuda")]
pub struct DeviceScalar<T>
where
    T: DeviceRepr + ValidAsZeroBits + Copy,
{
    /// One-element device buffer holding the current value.
    buf: CudaSlice<T>,
    /// Stream on which every host-to-device copy into `buf` is enqueued.
    stream: Arc<CudaStream>,
}
#[cfg(feature = "cuda")]
impl<T> DeviceScalar<T>
where
    T: DeviceRepr + ValidAsZeroBits + Copy,
{
    /// Allocates a one-element device buffer on `stream` holding `initial`.
    ///
    /// A reference-counted handle to `stream` is kept and reused by
    /// [`update`](Self::update).
    ///
    /// # Errors
    /// Returns an error if the host-to-device copy fails.
    pub fn new(stream: &Arc<CudaStream>, initial: T) -> GpuResult<Self> {
        let host = [initial];
        let buf = stream.clone_htod(&host)?;
        let stream = Arc::clone(stream);
        Ok(Self { buf, stream })
    }

    /// Overwrites the device-resident value with `value`.
    ///
    /// The copy is enqueued on the stream passed to [`new`](Self::new).
    ///
    /// NOTE(review): the copy source is a temporary host array. Confirm this is
    /// never called while that stream is capturing a graph. Also confirm that
    /// consumers running on other streams are ordered after this copy.
    ///
    /// # Errors
    /// Returns an error if the host-to-device copy fails.
    pub fn update(&mut self, value: T) -> GpuResult<()> {
        // Borrow the two fields separately: the stream immutably, the buffer mutably.
        let Self { buf, stream } = self;
        stream.memcpy_htod(&[value], buf)?;
        Ok(())
    }

    /// Borrows the underlying one-element device buffer (e.g. to pass it to a kernel).
    #[inline]
    pub fn inner(&self) -> &CudaSlice<T> {
        &self.buf
    }
}
/// An executable CUDA graph obtained from `end_capture`.
#[cfg(feature = "cuda")]
pub struct CapturedGraph {
    /// Graph returned by ending the stream capture.
    graph: cudarc::driver::CudaGraph,
}

#[cfg(feature = "cuda")]
impl CapturedGraph {
    /// Launches the captured graph.
    ///
    /// # Errors
    /// Returns an error if the driver rejects the launch.
    pub fn launch(&self) -> GpuResult<()> {
        self.graph.launch().map_err(GpuError::from)
    }
}
/// Starts recording work submitted to `stream` into a CUDA graph.
///
/// Capture uses `CU_STREAM_CAPTURE_MODE_THREAD_LOCAL`. Per the CUDA driver docs for
/// `cuStreamBeginCapture`, only captures owned by the calling thread then restrict
/// its potentially unsafe API calls. Pair with `end_capture` on the same stream.
///
/// # Errors
/// Returns an error if the driver refuses to begin the capture.
#[cfg(feature = "cuda")]
pub fn begin_capture(stream: &Arc<CudaStream>) -> GpuResult<()> {
    let mode = cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL;
    stream.begin_capture(mode).map_err(GpuError::from)
}
/// Stops the capture on `stream` and instantiates the recorded graph.
///
/// Instantiation uses `CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH`. Per the CUDA
/// driver docs, this frees memory allocated inside the graph before each relaunch.
///
/// # Errors
/// Propagates driver errors from ending the capture. If the capture yields no graph,
/// returns `GpuError::PtxCompileFailed` with the label
/// `"CUDA graph capture returned null"`.
#[cfg(feature = "cuda")]
pub fn end_capture(stream: &Arc<CudaStream>) -> GpuResult<CapturedGraph> {
    let flags = cudarc::driver::sys::CUgraphInstantiate_flags_enum::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;
    match stream.end_capture(flags)? {
        Some(graph) => Ok(CapturedGraph { graph }),
        // NOTE(review): `PtxCompileFailed` is reused although no PTX is compiled here;
        // a dedicated error variant would make this failure easier to diagnose.
        None => Err(GpuError::PtxCompileFailed {
            kernel: "CUDA graph capture returned null",
        }),
    }
}
/// Placeholder device scalar used when the crate is built without the `cuda`
/// feature. It is zero-sized and holds no data.
#[cfg(not(feature = "cuda"))]
pub struct DeviceScalar<T>
where
    T: Copy,
{
    // Marks the otherwise-unused type parameter `T` as part of the type.
    _phantom: std::marker::PhantomData<T>,
}
/// Placeholder for an instantiated CUDA graph when the `cuda` feature is disabled.
#[cfg(not(feature = "cuda"))]
pub struct CapturedGraph;

/// Common failure returned by every graph entry point in builds without `cuda`.
#[cfg(not(feature = "cuda"))]
fn cuda_unavailable<R>() -> GpuResult<R> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
impl CapturedGraph {
    /// Stub launch for builds without the `cuda` feature.
    ///
    /// # Errors
    /// Always returns `GpuError::NoCudaFeature`.
    pub fn launch(&self) -> GpuResult<()> {
        cuda_unavailable()
    }
}

/// Stub of `begin_capture` for builds without the `cuda` feature.
///
/// Generic over the stream argument, presumably so call sites compile without
/// `cudarc` types in scope.
///
/// # Errors
/// Always returns `GpuError::NoCudaFeature`.
#[cfg(not(feature = "cuda"))]
pub fn begin_capture<T>(_stream: &T) -> GpuResult<()> {
    cuda_unavailable()
}

/// Stub of `end_capture` for builds without the `cuda` feature.
///
/// # Errors
/// Always returns `GpuError::NoCudaFeature`.
#[cfg(not(feature = "cuda"))]
pub fn end_capture<T>(_stream: &T) -> GpuResult<CapturedGraph> {
    cuda_unavailable()
}