use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::driver::sys as driver_sys;
use cudarc::runtime::sys as runtime_sys;
use tokio::sync::oneshot;
use crate::error::GpuError;
/// Translates a [`PrefetchTarget`] into the driver-level `CUmemLocation`
/// that is handed to the prefetch call.
///
/// * `PrefetchTarget::Device(ordinal)` yields a location of type
///   `CU_MEM_LOCATION_TYPE_DEVICE` whose `id` is the requested ordinal.
/// * `PrefetchTarget::Cpu` yields a location of type
///   `CU_MEM_LOCATION_TYPE_HOST`; `id` is left at zero.
fn driver_location(target: PrefetchTarget) -> driver_sys::CUmemLocation {
    // SAFETY: the value is fully overwritten field-by-field below before it is
    // returned. NOTE(review): relies on the all-zero bit pattern being a valid
    // `CUmemLocation` in the generated bindings — confirm against cudarc.
    let mut loc: driver_sys::CUmemLocation = unsafe { std::mem::zeroed() };
    match target {
        PrefetchTarget::Device(ordinal) => {
            loc.type_ = driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
            // Previously the ordinal was discarded, so every device prefetch
            // silently targeted device 0. An ordinal that does not fit in the
            // C `int` saturates to a value no real device uses, so the call
            // fails instead of wrapping onto a valid device.
            loc.id = i32::try_from(ordinal).unwrap_or(i32::MAX);
        }
        PrefetchTarget::Cpu => {
            loc.type_ = driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST;
        }
    }
    loc
}
/// Attach mode requested when allocating managed memory.
///
/// Each variant is translated by `ManagedFlags::raw` into the matching CUDA
/// runtime constant.
#[derive(Debug, Clone, Copy)]
pub enum ManagedFlags {
/// Translated to `cudaMemAttachGlobal`.
AttachGlobal,
/// Translated to `cudaMemAttachHost`.
AttachHost,
}
impl ManagedFlags {
    /// Returns the CUDA runtime flag constant that corresponds to this
    /// attach mode, ready to be passed to `cudaMallocManaged`.
    fn raw(self) -> u32 {
        match self {
            Self::AttachHost => runtime_sys::cudaMemAttachHost,
            Self::AttachGlobal => runtime_sys::cudaMemAttachGlobal,
        }
    }
}
/// Destination of a prefetch request carried by `ManagedMsg::PrefetchF32`.
#[derive(Debug, Clone, Copy)]
pub enum PrefetchTarget {
/// Prefetch towards a device; the payload is presumably the device
/// ordinal — TODO confirm against callers.
Device(u32),
/// Prefetch towards the host (mapped to the driver's host location type).
Cpu,
}
/// Running totals reported by the allocator actor in reply to
/// `ManagedMsg::Stats`.
///
/// In the code visible here the counters only ever grow: they are bumped on a
/// successful allocation and never decremented when memory is released.
#[derive(Debug, Clone, Copy, Default)]
pub struct ManagedStats {
/// Number of successful managed allocations.
pub allocations: usize,
/// Sum of the byte sizes of all successful managed allocations.
pub bytes_allocated: usize,
}
/// Cloneable, typed handle to a managed allocation.
///
/// All clones share one `ManagedRefInner` through an `Arc`; the accessors
/// treat a missing inner record as an invalid, empty handle.
pub struct ManagedRef<T> {
// Shared allocation record; `None` is handled by every accessor as
// "invalid / zero length / null pointer".
inner: Option<Arc<ManagedRefInner>>,
// Ties the handle to its element type without storing a `T`.
_marker: PhantomData<T>,
}
/// Untyped record describing one managed allocation; released in `Drop`.
struct ManagedRefInner {
// Base address of the allocation (passed to `cudaFree` on drop).
ptr: NonNull<u8>,
// Size of the allocation in bytes.
bytes: usize,
// Number of typed elements the allocation was requested for.
elements: usize,
// Flag shared with the allocator actor; the actor clears it in `post_stop`.
system_alive: Arc<AtomicBool>,
}
impl Drop for ManagedRefInner {
    /// Releases the allocation with `cudaFree`, but only while the owning
    /// allocator is still flagged alive. Once the flag has been cleared the
    /// pointer is deliberately left untouched (the memory is not freed here).
    fn drop(&mut self) {
        let allocator_running = self.system_alive.load(Ordering::Acquire);
        if !allocator_running {
            return;
        }
        let device_ptr = self.ptr.as_ptr().cast::<std::ffi::c_void>();
        // SAFETY: `device_ptr` is the non-null base pointer stored at
        // construction and this is the only place it is freed. The returned
        // status is intentionally ignored: nothing can be reported from `drop`.
        unsafe {
            let _ = runtime_sys::cudaFree(device_ptr);
        }
    }
}
// `NonNull<u8>` suppresses the automatic `Send`/`Sync` impls, so they are
// asserted manually here. Within this file the pointer is never written after
// construction and is only released in `Drop`.
// NOTE(review): soundness also depends on the pointed-to allocation being
// usable from any thread — confirm against the CUDA runtime's guarantees.
unsafe impl Send for ManagedRefInner {}
unsafe impl Sync for ManagedRefInner {}
impl<T> ManagedRef<T> {
    /// Returns `true` when the handle has an allocation record and the
    /// allocator that produced it is still flagged alive.
    pub fn is_valid(&self) -> bool {
        match &self.inner {
            Some(record) => record.system_alive.load(Ordering::Acquire),
            None => false,
        }
    }

    /// Number of elements recorded for the allocation (0 for a handle without
    /// an allocation record). The liveness flag is not consulted.
    pub fn len(&self) -> usize {
        match &self.inner {
            Some(record) => record.elements,
            None => 0,
        }
    }

    /// Returns `true` when [`len`](Self::len) is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Raw const pointer to the first element, or null for a handle without
    /// an allocation record. The liveness flag is not consulted.
    pub fn as_ptr(&self) -> *const T {
        match &self.inner {
            Some(record) => record.ptr.as_ptr() as *const T,
            None => std::ptr::null(),
        }
    }

    /// Raw mutable pointer to the first element, or null for a handle without
    /// an allocation record. The liveness flag is not consulted.
    pub fn as_mut_ptr(&self) -> *mut T {
        match &self.inner {
            Some(record) => record.ptr.as_ptr() as *mut T,
            None => std::ptr::null_mut(),
        }
    }
}
impl<T: Copy> ManagedRef<T> {
    /// Views the allocation as a shared slice of `len()` elements.
    ///
    /// Returns an empty slice when the handle has no allocation record or the
    /// allocator is no longer flagged alive.
    pub fn as_slice(&self) -> &[T] {
        let Some(record) = self.inner.as_ref() else {
            return &[];
        };
        if !record.system_alive.load(Ordering::Acquire) {
            return &[];
        }
        let first = record.ptr.as_ptr() as *const T;
        // SAFETY: `first` is the non-null base pointer of the allocation and
        // `elements` is the count recorded when the record was created.
        // NOTE(review): assumes the allocation really holds `elements` values
        // of `T` and that no device work is writing to it concurrently.
        unsafe { std::slice::from_raw_parts(first, record.elements) }
    }

    /// Views the allocation as a mutable slice of `len()` elements.
    ///
    /// Returns an empty slice when the handle has no allocation record or the
    /// allocator is no longer flagged alive.
    ///
    /// NOTE(review): the handle is `Clone` and clones share the same memory,
    /// so two clones can each obtain a `&mut [T]` over the same bytes. Callers
    /// must guarantee exclusivity themselves.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        let Some(record) = self.inner.as_ref() else {
            return &mut [];
        };
        if !record.system_alive.load(Ordering::Acquire) {
            return &mut [];
        }
        let first = record.ptr.as_ptr() as *mut T;
        // SAFETY: same pointer/length invariants as `as_slice`; exclusivity is
        // only enforced per handle (`&mut self`), not across clones.
        unsafe { std::slice::from_raw_parts_mut(first, record.elements) }
    }
}
impl<T> Clone for ManagedRef<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
_marker: PhantomData,
}
}
}
// The handle is sendable/shareable exactly when its element type is; the
// shared `ManagedRefInner` record is itself declared `Send + Sync`.
// NOTE(review): `as_mut_ptr` takes `&self`, so synchronising writes made
// through the raw pointer is left to callers.
unsafe impl<T: Send> Send for ManagedRef<T> {}
unsafe impl<T: Sync> Sync for ManagedRef<T> {}
/// Messages understood by `ManagedAllocatorActor`. Every variant carries a
/// one-shot `reply` channel on which the actor sends the outcome.
pub enum ManagedMsg {
/// Allocate managed memory for `len` `f32` elements with the given attach
/// `flags`; replies with the new handle or an error.
AllocateManagedF32 {
len: usize,
flags: ManagedFlags,
reply: oneshot::Sender<Result<ManagedRef<f32>, GpuError>>,
},
/// Issue a prefetch of the whole allocation behind `mem` towards `target`.
PrefetchF32 {
mem: ManagedRef<f32>,
target: PrefetchTarget,
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Apply `advice` to the whole allocation behind `mem`.
AdviseF32 {
mem: ManagedRef<f32>,
advice: super::advise::MemAdvice,
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Request a copy of the actor's current `ManagedStats`.
Stats {
reply: oneshot::Sender<ManagedStats>,
},
}
/// Actor that serves managed-memory allocation, prefetch, advise and stats
/// requests (see `ManagedMsg`).
pub struct ManagedAllocatorActor {
// Liveness flag shared with every handle this actor hands out; set to
// `false` in `post_stop`.
system_alive: Arc<AtomicBool>,
// Running allocation totals returned for `ManagedMsg::Stats`.
stats: ManagedStats,
}
impl ManagedAllocatorActor {
    /// Builds the `Props` used to spawn the actor. Each instance starts with
    /// its liveness flag set and zeroed statistics.
    pub fn props() -> Props<Self> {
        Props::create(|| ManagedAllocatorActor {
            system_alive: Arc::new(AtomicBool::new(true)),
            stats: ManagedStats::default(),
        })
    }

    /// Allocates managed memory for `len` `f32` elements.
    ///
    /// # Errors
    ///
    /// * `GpuError::Unrecoverable` when `len * size_of::<f32>()` overflows,
    ///   when the runtime call panics (reported as "CUDA runtime not
    ///   loadable"), or when the call succeeds but hands back a null pointer.
    /// * `GpuError::OutOfMemory` for every non-success status returned by
    ///   `cudaMallocManaged`.
    ///
    /// On success the allocation counters are updated and the returned handle
    /// shares this actor's liveness flag.
    fn allocate_f32(
        &mut self,
        len: usize,
        flags: ManagedFlags,
    ) -> Result<ManagedRef<f32>, GpuError> {
        let bytes = len.checked_mul(std::mem::size_of::<f32>()).ok_or_else(|| {
            GpuError::Unrecoverable("managed alloc: len * size_of overflowed".into())
        })?;

        let mut raw: *mut std::ffi::c_void = std::ptr::null_mut();
        // The runtime call is wrapped in `catch_unwind` so that a panic raised
        // while invoking it becomes an error reply instead of unwinding
        // through the actor. `AssertUnwindSafe` already permits borrowing
        // `raw` mutably inside the closure, so the out-pointer is passed
        // directly rather than being round-tripped through `usize`.
        let call = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            // SAFETY: `&mut raw` is a valid, writable out-pointer for the
            // duration of the call; `bytes` and the flag are plain values.
            unsafe { runtime_sys::cudaMallocManaged(&mut raw, bytes, flags.raw()) }
        }));
        let status = call.map_err(|_| {
            GpuError::Unrecoverable("cudaMallocManaged: CUDA runtime not loadable".into())
        })?;

        if status != runtime_sys::cudaError::cudaSuccess {
            return Err(GpuError::OutOfMemory(format!(
                "cudaMallocManaged({bytes}B): {status:?}"
            )));
        }

        let ptr = NonNull::new(raw as *mut u8)
            .ok_or_else(|| GpuError::Unrecoverable("cudaMallocManaged returned null".into()))?;

        // Counters are informational; saturate rather than risk an overflow
        // panic in a long-running actor.
        self.stats.allocations = self.stats.allocations.saturating_add(1);
        self.stats.bytes_allocated = self.stats.bytes_allocated.saturating_add(bytes);

        Ok(ManagedRef {
            inner: Some(Arc::new(ManagedRefInner {
                ptr,
                bytes,
                elements: len,
                system_alive: self.system_alive.clone(),
            })),
            _marker: PhantomData,
        })
    }
}
#[async_trait]
impl Actor for ManagedAllocatorActor {
type Msg = ManagedMsg;
async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ManagedMsg) {
match msg {
ManagedMsg::AllocateManagedF32 { len, flags, reply } => {
let _ = reply.send(self.allocate_f32(len, flags));
}
ManagedMsg::PrefetchF32 { mem, target, reply } => {
let Some(inner) = mem.inner.as_ref() else {
let _ = reply.send(Err(GpuError::Unrecoverable(
"PrefetchF32: invalid ManagedRef".into(),
)));
return;
};
if !inner.system_alive.load(Ordering::Acquire) {
let _ = reply.send(Err(GpuError::Unrecoverable(
"PrefetchF32: allocator stopped".into(),
)));
return;
}
let location = driver_location(target);
let dev_ptr = inner.ptr.as_ptr() as cudarc::driver::sys::CUdeviceptr;
let r = crate::sys::cuda_driver::mem_prefetch_async_v2(
dev_ptr,
inner.bytes,
location,
0,
std::ptr::null_mut(),
);
let _ = reply.send(r);
}
ManagedMsg::AdviseF32 { mem, advice, reply } => {
let Some(inner) = mem.inner.as_ref() else {
let _ = reply.send(Err(GpuError::Unrecoverable(
"AdviseF32: invalid ManagedRef".into(),
)));
return;
};
if !inner.system_alive.load(Ordering::Acquire) {
let _ = reply.send(Err(GpuError::Unrecoverable(
"AdviseF32: allocator stopped".into(),
)));
return;
}
let dev_ptr = inner.ptr.as_ptr() as cudarc::driver::sys::CUdeviceptr;
let r = crate::memory::advise::advise(dev_ptr, inner.bytes, advice);
let _ = reply.send(r);
}
ManagedMsg::Stats { reply } => {
let _ = reply.send(self.stats);
}
}
}
async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
self.system_alive.store(false, Ordering::Release);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Number of `f32` elements requested by the allocation test.
    const ALLOC_LEN: usize = 1024;

    /// The allocator must answer an allocation request (with success or an
    /// error, depending on whether CUDA is usable on this machine) and its
    /// statistics must agree with that outcome.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn allocate_replies_then_invalidate_on_post_stop() {
        let sys = ActorSystem::create("managed-test", Config::empty())
            .await
            .unwrap();
        let mgr = sys
            .actor_of(ManagedAllocatorActor::props(), "managed")
            .unwrap();

        let (tx, rx) = oneshot::channel();
        mgr.tell(ManagedMsg::AllocateManagedF32 {
            len: ALLOC_LEN,
            flags: ManagedFlags::AttachGlobal,
            reply: tx,
        });
        let r = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        let allocated = r.is_ok();

        let (tx, rx) = oneshot::channel();
        mgr.tell(ManagedMsg::Stats { reply: tx });
        let stats = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();

        // The counters are only bumped on a successful allocation, so they
        // must mirror the reply received above.
        let (want_allocations, want_bytes) = if allocated {
            (1, ALLOC_LEN * std::mem::size_of::<f32>())
        } else {
            (0, 0)
        };
        assert_eq!(stats.allocations, want_allocations);
        assert_eq!(stats.bytes_allocated, want_bytes);

        sys.terminate().await;
        // The handle is kept until after shutdown, as before.
        drop(r);
    }

    /// Builds a `ManagedRef` that does NOT point at CUDA memory, for routing
    /// tests only.
    ///
    /// The pointer refers to a single intentionally leaked host byte, while
    /// `bytes`/`elements` describe the pretended size — the handle must never
    /// be dereferenced. Callers must clear the returned flag before the last
    /// clone is dropped so that `Drop` does not hand the host pointer to
    /// `cudaFree`.
    fn synthetic_managed_ref<T>(elements: usize) -> (ManagedRef<T>, Arc<AtomicBool>) {
        let alive = Arc::new(AtomicBool::new(true));
        let raw = NonNull::from(Box::leak(Box::new(0u8)));
        let mref = ManagedRef::<T> {
            inner: Some(Arc::new(ManagedRefInner {
                ptr: raw,
                bytes: elements * std::mem::size_of::<T>(),
                elements,
                system_alive: alive.clone(),
            })),
            _marker: PhantomData,
        };
        (mref, alive)
    }

    /// Accepts every outcome that proves the message reached the actor and a
    /// reply came back: success, or one of the expected error kinds when CUDA
    /// is unavailable or rejects the synthetic pointer.
    fn assert_routed(outcome: Result<(), GpuError>) {
        match outcome {
            Ok(()) => {}
            Err(GpuError::Unrecoverable(_)) => {}
            Err(GpuError::LibraryError { .. }) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn prefetch_message_routes_through_actor() {
        let sys = ActorSystem::create("managed-prefetch-test", Config::empty())
            .await
            .unwrap();
        let mgr = sys
            .actor_of(ManagedAllocatorActor::props(), "managed")
            .unwrap();
        let (mref, alive) = synthetic_managed_ref::<f32>(64);

        let (tx, rx) = oneshot::channel();
        mgr.tell(ManagedMsg::PrefetchF32 {
            mem: mref.clone(),
            target: PrefetchTarget::Cpu,
            reply: tx,
        });
        let r = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();

        // Invalidate before asserting: if the assertion panics, unwinding must
        // not drop a still-"alive" synthetic handle into `cudaFree`.
        alive.store(false, Ordering::Release);
        drop(mref);
        assert_routed(r);
        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn advise_message_routes_through_actor() {
        let sys = ActorSystem::create("managed-advise-test", Config::empty())
            .await
            .unwrap();
        let mgr = sys
            .actor_of(ManagedAllocatorActor::props(), "managed")
            .unwrap();
        let (mref, alive) = synthetic_managed_ref::<f32>(64);

        let (tx, rx) = oneshot::channel();
        mgr.tell(ManagedMsg::AdviseF32 {
            mem: mref.clone(),
            advice: super::super::advise::MemAdvice::SetReadMostly,
            reply: tx,
        });
        let r = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();

        // Invalidate before asserting: if the assertion panics, unwinding must
        // not drop a still-"alive" synthetic handle into `cudaFree`.
        alive.store(false, Ordering::Release);
        drop(mref);
        assert_routed(r);
        sys.terminate().await;
    }
}