use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::runtime::sys as runtime_sys;
use tokio::sync::oneshot;
use crate::error::GpuError;
/// Attachment mode requested for a managed allocation.
///
/// Each variant is translated by `ManagedFlags::raw` into the matching
/// `cudaMemAttach*` constant of the CUDA runtime bindings, and that value is
/// what the allocator hands to `cudaMallocManaged`.
#[derive(Debug, Clone, Copy)]
pub enum ManagedFlags {
/// Translated to `cudaMemAttachGlobal`.
AttachGlobal,
/// Translated to `cudaMemAttachHost`.
AttachHost,
}
impl ManagedFlags {
    /// Converts the flag into the raw `u32` constant exposed by the CUDA
    /// runtime bindings (`cudaMemAttachGlobal` / `cudaMemAttachHost`).
    fn raw(self) -> u32 {
        match self {
            Self::AttachGlobal => runtime_sys::cudaMemAttachGlobal,
            Self::AttachHost => runtime_sys::cudaMemAttachHost,
        }
    }
}
/// Destination of a managed-memory prefetch request.
#[derive(Debug, Clone, Copy)]
pub enum PrefetchTarget {
/// Prefetch towards a GPU. The payload becomes the `id` of the CUDA memory
/// location (presumably the device ordinal — confirm against callers).
Device(u32),
/// Prefetch towards host memory; the location id is set to `0`.
Cpu,
}
/// Counters kept by the allocator actor and returned in reply to
/// `ManagedMsg::Stats`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ManagedStats {
/// Number of allocations that completed successfully.
pub allocations: usize,
/// Sum of the byte sizes of all successful allocations. Nothing in this
/// file decrements it, so it is a cumulative total rather than live usage.
pub bytes_allocated: usize,
}
/// Cloneable, typed handle to a CUDA managed allocation.
///
/// Clones share one allocation through an `Arc`; the allocation's `Drop`
/// logic therefore runs when the last clone goes away.
pub struct ManagedRef<T> {
// Shared allocation state. Every accessor treats `None` as an invalid,
// zero-length handle.
inner: Option<Arc<ManagedRefInner>>,
// Records the element type without storing a `T`.
_marker: PhantomData<T>,
}
/// Shared state behind a `ManagedRef`: the raw allocation plus bookkeeping.
struct ManagedRefInner {
// Base address produced by `cudaMallocManaged`.
ptr: NonNull<u8>,
// Size of the allocation in bytes; used as the prefetch length.
bytes: usize,
// Number of elements the allocation was requested for.
elements: usize,
// Flag shared with the allocator actor; the actor clears it in `post_stop`.
system_alive: Arc<AtomicBool>,
}
impl Drop for ManagedRefInner {
    /// Returns the allocation to CUDA, but only while the allocator's
    /// liveness flag is still set.
    ///
    /// Once the flag has been cleared the pointer is left untouched, so this
    /// code never frees the memory in that case.
    fn drop(&mut self) {
        let allocator_running = self.system_alive.load(Ordering::Acquire);
        if !allocator_running {
            return;
        }
        // SAFETY: within this file `ptr` only ever comes from a successful
        // `cudaMallocManaged` call, and `drop` runs once for the shared
        // state, so the pointer is freed at most once. The returned status is
        // discarded because `drop` has no way to report it.
        unsafe {
            let _ = runtime_sys::cudaFree(self.ptr.as_ptr().cast());
        }
    }
}
// `NonNull<u8>` suppresses the auto traits, so thread-safety is asserted by hand.
// NOTE(review): these impls promise that the managed pointer may be moved to and
// shared between host threads. Nothing in this file synchronises access to the
// pointed-to memory — confirm that callers coordinate host/device access.
unsafe impl Send for ManagedRefInner {}
unsafe impl Sync for ManagedRefInner {}
impl<T> ManagedRef<T> {
    /// Reports whether the handle points at an allocation whose allocator is
    /// still running. A handle without shared state is never valid.
    pub fn is_valid(&self) -> bool {
        match &self.inner {
            Some(shared) => shared.system_alive.load(Ordering::Acquire),
            None => false,
        }
    }

    /// Number of elements in the allocation, or `0` for a handle without
    /// shared state. The allocator's liveness flag is not consulted.
    pub fn len(&self) -> usize {
        self.inner.as_ref().map_or(0, |shared| shared.elements)
    }

    /// `true` when [`len`](Self::len) is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Base address of the allocation as a const pointer, or null for a
    /// handle without shared state. The liveness flag is not consulted.
    pub fn as_ptr(&self) -> *const T {
        match &self.inner {
            Some(shared) => shared.ptr.as_ptr() as *const T,
            None => std::ptr::null(),
        }
    }

    /// Base address of the allocation as a mutable pointer, or null for a
    /// handle without shared state. The liveness flag is not consulted.
    pub fn as_mut_ptr(&self) -> *mut T {
        match &self.inner {
            Some(shared) => shared.ptr.as_ptr() as *mut T,
            None => std::ptr::null_mut(),
        }
    }
}
impl<T: Copy> ManagedRef<T> {
pub fn as_slice(&self) -> &[T] {
match self.inner.as_ref() {
None => &[],
Some(i) => {
if !i.system_alive.load(Ordering::Acquire) {
return &[];
}
unsafe { std::slice::from_raw_parts(i.ptr.as_ptr() as *const T, i.elements) }
}
}
}
pub fn as_mut_slice(&mut self) -> &mut [T] {
match self.inner.as_ref() {
None => &mut [],
Some(i) => {
if !i.system_alive.load(Ordering::Acquire) {
return &mut [];
}
unsafe { std::slice::from_raw_parts_mut(i.ptr.as_ptr() as *mut T, i.elements) }
}
}
}
}
impl<T> Clone for ManagedRef<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
_marker: PhantomData,
}
}
}
// The handle is `Send`/`Sync` exactly when its element type is.
// NOTE(review): with `ManagedRefInner` already `Send + Sync` and the
// `PhantomData<T>` marker, the auto traits presumably yield the same bounds —
// confirm whether these manual impls are still required.
unsafe impl<T: Send> Send for ManagedRef<T> {}
unsafe impl<T: Sync> Sync for ManagedRef<T> {}
/// Messages handled by `ManagedAllocatorActor`. Every variant carries a
/// oneshot sender on which the actor posts its answer.
pub enum ManagedMsg {
/// Allocate managed memory for `len` `f32` elements.
AllocateManagedF32 {
/// Number of `f32` elements to allocate.
len: usize,
/// Attachment flag forwarded to `cudaMallocManaged`.
flags: ManagedFlags,
/// Receives the new handle, or the allocation error.
reply: oneshot::Sender<Result<ManagedRef<f32>, GpuError>>,
},
/// Issue a prefetch of the whole allocation behind `mem` towards `target`.
PrefetchF32 {
/// Handle of the allocation to prefetch.
mem: ManagedRef<f32>,
/// Where the memory should be prefetched to.
target: PrefetchTarget,
/// Receives `Ok(())` once the prefetch call reported success, or the error.
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Request a copy of the allocator's counters.
Stats {
/// Receives the current `ManagedStats`.
reply: oneshot::Sender<ManagedStats>,
},
}
/// Actor that creates CUDA managed allocations and issues prefetches for them.
pub struct ManagedAllocatorActor {
// Liveness flag shared with every handle this actor creates; set to `false`
// in `post_stop`.
system_alive: Arc<AtomicBool>,
// Counters updated on each successful allocation.
stats: ManagedStats,
}
impl ManagedAllocatorActor {
    /// Builds the `Props` that spawn an allocator with a fresh liveness flag
    /// (initially `true`) and zeroed statistics.
    pub fn props() -> Props<Self> {
        Props::create(|| ManagedAllocatorActor {
            system_alive: Arc::new(AtomicBool::new(true)),
            stats: ManagedStats::default(),
        })
    }

    /// Allocates managed memory for `len` `f32` elements.
    ///
    /// On success the statistics are updated and a handle sharing this
    /// actor's liveness flag is returned.
    ///
    /// # Errors
    ///
    /// * `GpuError::Unrecoverable` — the byte size overflows `usize`, the
    ///   binding call panicked (reported as "CUDA runtime not loadable"), or
    ///   the call reported success but produced a null pointer.
    /// * `GpuError::OutOfMemory` — `cudaMallocManaged` returned any status
    ///   other than `cudaSuccess`.
    fn allocate_f32(
        &mut self,
        len: usize,
        flags: ManagedFlags,
    ) -> Result<ManagedRef<f32>, GpuError> {
        let bytes = len.checked_mul(std::mem::size_of::<f32>()).ok_or_else(|| {
            GpuError::Unrecoverable("managed alloc: len * size_of overflowed".into())
        })?;

        let mut raw: *mut std::ffi::c_void = std::ptr::null_mut();
        // The binding call may panic, so it runs under `catch_unwind`. The
        // out-pointer is borrowed directly: `AssertUnwindSafe` already
        // permits the capture, so no pointer-to-integer round trip is needed.
        let call = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            // SAFETY: `raw` is a live local that outlives the call and is a
            // valid location for the runtime to write the address into.
            unsafe { runtime_sys::cudaMallocManaged(&mut raw, bytes, flags.raw()) }
        }));
        let status = call.map_err(|_| {
            GpuError::Unrecoverable("cudaMallocManaged: CUDA runtime not loadable".into())
        })?;
        if status != runtime_sys::cudaError::cudaSuccess {
            return Err(GpuError::OutOfMemory(format!(
                "cudaMallocManaged({bytes}B): {status:?}"
            )));
        }
        let ptr = NonNull::new(raw as *mut u8)
            .ok_or_else(|| GpuError::Unrecoverable("cudaMallocManaged returned null".into()))?;

        // Saturate instead of `+=` so a counter overflow can never panic
        // inside the actor.
        self.stats.allocations = self.stats.allocations.saturating_add(1);
        self.stats.bytes_allocated = self.stats.bytes_allocated.saturating_add(bytes);

        Ok(ManagedRef {
            inner: Some(Arc::new(ManagedRefInner {
                ptr,
                bytes,
                elements: len,
                system_alive: Arc::clone(&self.system_alive),
            })),
            _marker: PhantomData,
        })
    }
}
#[async_trait]
impl Actor for ManagedAllocatorActor {
type Msg = ManagedMsg;
async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ManagedMsg) {
match msg {
ManagedMsg::AllocateManagedF32 { len, flags, reply } => {
let _ = reply.send(self.allocate_f32(len, flags));
}
ManagedMsg::PrefetchF32 { mem, target, reply } => {
let Some(inner) = mem.inner.as_ref() else {
let _ = reply.send(Err(GpuError::Unrecoverable(
"PrefetchF32: invalid ManagedRef".into(),
)));
return;
};
if !inner.system_alive.load(Ordering::Acquire) {
let _ = reply.send(Err(GpuError::Unrecoverable(
"PrefetchF32: allocator stopped".into(),
)));
return;
}
let location = runtime_sys::cudaMemLocation {
type_: match target {
PrefetchTarget::Device(_) => {
runtime_sys::cudaMemLocationType::cudaMemLocationTypeDevice
}
PrefetchTarget::Cpu => {
runtime_sys::cudaMemLocationType::cudaMemLocationTypeHost
}
},
id: match target {
PrefetchTarget::Device(d) => d as i32,
PrefetchTarget::Cpu => 0,
},
};
let status = unsafe {
runtime_sys::cudaMemPrefetchAsync(
inner.ptr.as_ptr() as *const _,
inner.bytes,
location,
0,
std::ptr::null_mut(),
)
};
if status != runtime_sys::cudaError::cudaSuccess {
let _ = reply.send(Err(GpuError::LibraryError {
lib: "runtime",
msg: format!("cudaMemPrefetchAsync: {status:?}"),
}));
return;
}
let _ = reply.send(Ok(()));
}
ManagedMsg::Stats { reply } => {
let _ = reply.send(self.stats);
}
}
}
async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
self.system_alive.store(false, Ordering::Release);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Spawns the allocator, checks that both an allocation request and a
    /// stats request are answered within two seconds, then shuts the system
    /// down.
    ///
    /// The allocation result itself is not inspected (presumably so the test
    /// also passes on machines without a usable CUDA runtime — confirm).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn allocate_replies_then_invalidate_on_post_stop() {
        let reply_deadline = Duration::from_secs(2);

        let system = ActorSystem::create("managed-test", Config::empty())
            .await
            .unwrap();
        let allocator = system
            .actor_of(ManagedAllocatorActor::props(), "managed")
            .unwrap();

        let (alloc_tx, alloc_rx) = oneshot::channel();
        allocator.tell(ManagedMsg::AllocateManagedF32 {
            len: 1024,
            flags: ManagedFlags::AttachGlobal,
            reply: alloc_tx,
        });
        // Bound to a named variable so the result stays alive until the end
        // of the test, i.e. until after the system has terminated.
        let _alloc_outcome = tokio::time::timeout(reply_deadline, alloc_rx)
            .await
            .unwrap()
            .unwrap();

        let (stats_tx, stats_rx) = oneshot::channel();
        allocator.tell(ManagedMsg::Stats { reply: stats_tx });
        let _stats = tokio::time::timeout(reply_deadline, stats_rx)
            .await
            .unwrap()
            .unwrap();

        system.terminate().await;
    }
}