use std::sync::{Arc, Weak};
use arc_swap::ArcSwapOption;
use crate::device::DeviceState;
use crate::error::GpuError;
/// Cloneable handle to a device slice, tagged with the [`DeviceState`] generation it
/// was created under.
///
/// All data lives behind a single `Arc`, so every clone of a `GpuRef` refers to the
/// same slice, the same generation tag and the same last-write-stream slot.
pub struct GpuRef<T> {
// Shared payload; cloning the handle only clones this `Arc`.
inner: Arc<GpuRefInner<T>>,
}
/// Shared payload behind every clone of a [`GpuRef`].
struct GpuRefInner<T> {
/// The device slice this handle refers to.
slice: Arc<cudarc::driver::CudaSlice<T>>,
/// Value of `DeviceState::generation()` captured when the handle was constructed;
/// `GpuRef::access` compares it against the state's current generation.
generation: u64,
/// Weak link to the owning device state, so the handle does not keep it alive.
state: Weak<DeviceState>,
/// Stream most recently passed to `GpuRef::record_write`; empty until the first call.
last_write_stream: ArcSwapOption<cudarc::driver::CudaStream>,
}
impl<T> Clone for GpuRef<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
/// Debug output shows the captured generation and the slice length only.
impl<T> std::fmt::Debug for GpuRef<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner = &self.inner;
        let len = inner.slice.len();

        let mut out = f.debug_struct("GpuRef");
        out.field("generation", &inner.generation);
        out.field("len", &len);
        out.finish()
    }
}
impl<T> GpuRef<T> {
    /// Wraps `slice` in a handle tagged with the current generation of `state`.
    ///
    /// Only a weak reference to `state` is stored, so the handle does not keep the
    /// device state alive.
    pub fn new(slice: Arc<cudarc::driver::CudaSlice<T>>, state: &Arc<DeviceState>) -> Self {
        let payload = GpuRefInner {
            slice,
            generation: state.generation(),
            state: Arc::downgrade(state),
            last_write_stream: ArcSwapOption::empty(),
        };
        Self {
            inner: Arc::new(payload),
        }
    }

    /// Returns the underlying slice after checking that the handle is still usable.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::GpuRefStale`] when, checked in this order:
    /// - the device state has been dropped,
    /// - the device state is no longer accepting operations,
    /// - the device state's generation differs from the one captured at construction.
    pub fn access(&self) -> Result<&Arc<cudarc::driver::CudaSlice<T>>, GpuError> {
        let state = match self.inner.state.upgrade() {
            Some(live) => live,
            None => return Err(GpuError::GpuRefStale("device state dropped")),
        };

        // Shutdown is checked before the generation comparison.
        let stale_reason = if !state.accepting_ops() {
            Some("device shutting down")
        } else if state.generation() != self.inner.generation {
            Some("context rebuilt")
        } else {
            None
        };

        match stale_reason {
            Some(reason) => Err(GpuError::GpuRefStale(reason)),
            None => Ok(&self.inner.slice),
        }
    }

    /// Generation captured when this handle was created.
    pub fn generation(&self) -> u64 {
        let tagged = self.inner.generation;
        tagged
    }

    /// Length reported by the underlying slice. Performs no staleness check.
    pub fn len(&self) -> usize {
        let slice = &self.inner.slice;
        slice.len()
    }

    /// Whether the underlying slice reports itself empty. Performs no staleness check.
    pub fn is_empty(&self) -> bool {
        let slice = &self.inner.slice;
        slice.is_empty()
    }

    /// Device id of the owning state, or `None` once that state has been dropped.
    pub fn device_id(&self) -> Option<u32> {
        let state = self.inner.state.upgrade()?;
        Some(state.device_id())
    }

    /// Remembers `stream` as the most recent stream that wrote through this handle.
    ///
    /// The slot is shared by all clones of the handle; a later call replaces the
    /// previously stored stream.
    pub fn record_write(&self, stream: &Arc<cudarc::driver::CudaStream>) {
        let latest = Some(Arc::clone(stream));
        self.inner.last_write_stream.store(latest);
    }

    /// Stream stored by the latest [`record_write`](Self::record_write), if any.
    pub fn last_write_stream(&self) -> Option<Arc<cudarc::driver::CudaStream>> {
        let slot = &self.inner.last_write_stream;
        slot.load_full()
    }

    /// Raw device pointer of the slice, obtained against the slice's own stream.
    ///
    /// The guard returned together with the pointer is dropped before this function
    /// returns, so nothing ties the pointer to the handle afterwards.
    ///
    /// # Errors
    ///
    /// Propagates the [`GpuError::GpuRefStale`] from [`access`](Self::access).
    pub fn raw_device_ptr(&self) -> Result<u64, GpuError> {
        use cudarc::driver::DevicePtr;

        let slice = self.access()?;
        let (ptr, _guard) = slice.device_ptr(slice.stream());
        Ok(ptr)
    }
}
#[cfg(test)]
impl<T> GpuRef<T> {
    /// Builds a `GpuRef` for tests that must run without a GPU.
    ///
    /// The slice storage is an intentionally leaked, **uninitialised**
    /// `CudaSlice<T>`: no device memory exists behind it. The handle is only suitable
    /// for code paths that never read the slice. Anything that dereferences it
    /// (`len`, `is_empty`, the `Debug` impl, or the slice itself) reads uninitialised
    /// memory and must not be used on a handle from this constructor.
    ///
    /// The `DeviceState` created here is dropped when this function returns, while
    /// the handle keeps only a weak reference to it. As a result `access()` on the
    /// returned handle fails with "device state dropped" and `device_id()` returns
    /// `None`.
    pub(crate) fn for_test_no_gpu_leaked() -> Self {
        use std::mem::MaybeUninit;

        // Allocate through `Arc` (not `Box`): `Arc::from_raw` is only valid for
        // pointers produced by `Arc::into_raw`, because it expects the reference
        // counts to live next to the value inside the same allocation.
        let uninit: std::sync::Arc<MaybeUninit<cudarc::driver::CudaSlice<T>>> =
            std::sync::Arc::new(MaybeUninit::uninit());
        let raw = std::sync::Arc::into_raw(uninit).cast::<cudarc::driver::CudaSlice<T>>();

        // SAFETY: `raw` comes from `Arc::into_raw`, and `MaybeUninit<U>` is guaranteed
        // to have the same size and alignment as `U`, which is what `Arc::from_raw`
        // requires for a type-changing round trip. The pointee is uninitialised, so it
        // must never be read or dropped; the forgotten clone below guarantees the
        // latter.
        let arc_slice: std::sync::Arc<cudarc::driver::CudaSlice<T>> =
            unsafe { std::sync::Arc::from_raw(raw) };

        // Leak one strong reference so the count can never reach zero: the
        // uninitialised `CudaSlice` is then never dropped and its allocation never freed.
        std::mem::forget(std::sync::Arc::clone(&arc_slice));

        let state = std::sync::Arc::new(crate::device::DeviceState::new(0));
        Self::new(arc_slice, &state)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::DeviceState;

    /// Checks the two `DeviceState` signals that `GpuRef::access` consults: the
    /// generation counter and the accepting-ops flag.
    ///
    /// NOTE(review): despite its name this test never builds a `GpuRef`; it only
    /// exercises `DeviceState` directly.
    #[test]
    fn generation_mismatch_fails_validate() {
        let device = Arc::new(DeviceState::new(0));

        // Generation starts at zero and advances by one per bump.
        let initial = device.generation();
        assert_eq!(initial, 0);
        device.bump_generation();
        assert_eq!(device.generation(), 1);

        // Operations are accepted until shutdown begins.
        assert!(device.accepting_ops());
        device.begin_shutdown();
        assert!(!device.accepting_ops());
    }
}