use std::collections::HashMap;
use std::ffi::{c_void, CStr};
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::sync::{Arc, Mutex, OnceLock};
use kapsl_hal::gpu_arena::{GpuBlockPool, GpuPoolHandle};
use ort::memory::{AllocationDevice, AllocatorType, MemoryInfo, MemoryType};
use ort::sys as ort_sys;
use ort::AsPointer;
/// ORT session config key (`"session.use_env_allocators"`).
///
/// ORT uses this key to opt a session into allocators registered on the environment,
/// such as the one installed by `register_pool_allocator`.
pub const USE_ENV_ALLOCATORS_KEY: &str = "session.use_env_allocators";
/// Byte alignment that every pointer returned to ORT must satisfy.
///
/// `register_pool_allocator` rejects pools whose block size is not a multiple of this.
/// Each allocation starts on a block boundary, so the block size must preserve the
/// alignment (assuming the pool's base address is itself aligned; confirm in `kapsl_hal`).
const CUDA_ALLOC_ALIGN: usize = 256;
/// The contiguous run of pool blocks backing one live ORT allocation.
///
/// Recorded by `pool_alloc` and consumed by `pool_free`, so the exact run can be
/// handed back to the pool.
struct BlockRun {
    /// How many blocks the run spans.
    count: usize,
    /// First block of the run, as returned by `alloc_blocks_contiguous`.
    first: u32,
}
/// Rust-side state reached from every ORT callback of one registered allocator.
struct AllocState {
    /// Pool that backs every allocation handed to ORT.
    pool: Arc<GpuBlockPool>,
    /// Memory info returned to ORT from `pool_info`.
    mem_info: MemoryInfo,
    /// Cached `pool.bytes_per_block()`; used to round requests up to whole blocks.
    bytes_per_block: usize,
    /// Device id; in this file it is only read for log messages.
    device_id: i32,
    /// Live allocations keyed by device-pointer address.
    ///
    /// `pool_free` uses this to find the block run to return.
    live: Mutex<HashMap<usize, BlockRun>>,
}
/// An `OrtAllocator` vtable followed by the state its callbacks operate on.
///
/// `#[repr(C)]` with `ort` as the FIRST field guarantees that the struct and its `ort`
/// field share an address. That is what lets each callback cast the `OrtAllocator*`
/// ORT hands back into a `PoolOrtAllocator*` and reach `state`. Do not reorder fields.
#[repr(C)]
struct PoolOrtAllocator {
    ort: ort_sys::OrtAllocator,
    state: AllocState,
}
/// `OrtAllocator::Alloc` callback: serves `size` bytes from the shared GPU block pool.
///
/// The request is rounded up to a whole number of contiguous blocks. The returned
/// device pointer is recorded in `AllocState::live` so `pool_free` can recover the run.
///
/// Returns null in three cases:
/// - zero-sized requests;
/// - pool allocation failure (logged);
/// - a caught panic. Panics must not unwind across the FFI boundary.
///
/// # Safety
/// `this_` must be the allocator pointer registered with ORT by
/// `register_pool_allocator`, i.e. the start of a live `PoolOrtAllocator`.
unsafe extern "system" fn pool_alloc(
    this_: *mut ort_sys::OrtAllocator,
    size: usize,
) -> *mut c_void {
    catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: `this_` is the start of a `#[repr(C)]` `PoolOrtAllocator` whose first
        // field is the `OrtAllocator`, so the cast recovers the wrapper.
        let state = unsafe { &(*(this_ as *const PoolOrtAllocator)).state };
        if size == 0 {
            return std::ptr::null_mut();
        }
        let blocks = size.div_ceil(state.bytes_per_block);
        match state.pool.alloc_blocks_contiguous(blocks) {
            Ok(first) => {
                let ptr = state.pool.block_device_ptr(first) as usize;
                // The blocks are already taken from the pool at this point. Panicking on
                // a poisoned lock would leak them forever, so recover the guard instead.
                // This is sound: only single insert/remove calls run under this lock,
                // so the map is never left half-updated.
                state
                    .live
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner)
                    .insert(ptr, BlockRun { first, count: blocks });
                ptr as *mut c_void
            }
            Err(e) => {
                log::error!(
                    "ORT pool allocator (device {}): failed to allocate {} bytes ({} blocks): {}",
                    state.device_id, size, blocks, e
                );
                std::ptr::null_mut()
            }
        }
    }))
    .unwrap_or(std::ptr::null_mut())
}
/// `OrtAllocator::Free` callback: returns the block run behind `p` to the pool.
///
/// Null pointers are ignored. Pointers not found in `AllocState::live` are logged and
/// ignored. Panics are caught and discarded so they never unwind into ORT.
///
/// # Safety
/// `this_` must be the allocator pointer registered with ORT by
/// `register_pool_allocator`, i.e. the start of a live `PoolOrtAllocator`.
unsafe extern "system" fn pool_free(this_: *mut ort_sys::OrtAllocator, p: *mut c_void) {
    let _ = catch_unwind(AssertUnwindSafe(|| {
        if p.is_null() {
            return;
        }
        // SAFETY: `this_` is the start of a `#[repr(C)]` `PoolOrtAllocator` whose first
        // field is the `OrtAllocator`, so the cast recovers the wrapper.
        let state = unsafe { &(*(this_ as *const PoolOrtAllocator)).state };
        // Remove the entry in its own statement so the `live` guard is dropped before we
        // call into the pool or log. A guard created inside a `match` scrutinee would
        // stay alive until the end of the whole `match`.
        // A poisoned lock is recovered rather than panicking; otherwise every later free
        // would fail and leak its blocks. Only single insert/remove calls run under
        // this lock, so the map stays consistent.
        let run = state
            .live
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(&(p as usize));
        match run {
            Some(run) => state.pool.free_blocks_contiguous(run.first, run.count),
            None => log::warn!(
                "ORT pool allocator (device {}): free of unknown pointer {:p}",
                state.device_id, p
            ),
        }
    }));
}
/// `OrtAllocator::Info` callback: exposes the memory info owned by this allocator.
///
/// The returned pointer borrows from `AllocState::mem_info` and stays valid as long as
/// the wrapper lives.
///
/// # Safety
/// `this_` must be the allocator pointer registered with ORT by
/// `register_pool_allocator`, i.e. the start of a live `PoolOrtAllocator`.
unsafe extern "system" fn pool_info(
    this_: *const ort_sys::OrtAllocator,
) -> *const ort_sys::OrtMemoryInfo {
    let wrapper = this_.cast::<PoolOrtAllocator>();
    // SAFETY: `#[repr(C)]` places the `OrtAllocator` first, so `wrapper` addresses the
    // enclosing `PoolOrtAllocator`.
    let mem_info = unsafe { &(*wrapper).state.mem_info };
    mem_info.ptr()
}
/// `OrtAllocator::Reserve` callback: reservations are served exactly like `pool_alloc`.
///
/// # Safety
/// Same contract as `pool_alloc`.
unsafe extern "system" fn pool_reserve(
    this_: *const ort_sys::OrtAllocator,
    size: usize,
) -> *mut c_void {
    let allocator = this_.cast_mut();
    // SAFETY: forwarded verbatim; the caller upholds `pool_alloc`'s contract.
    unsafe { pool_alloc(allocator, size) }
}
/// Book-keeping for the pool registered with ORT for one device.
struct Registration {
    /// Handle of the registered pool.
    ///
    /// Returned by `registered_pool_handle`. Its `pool` is compared with
    /// `Arc::ptr_eq` to detect repeat registrations.
    handle: GpuPoolHandle,
    /// Strong reference to the ORT environment the allocator was registered on.
    ///
    /// Holding it keeps the environment alive while this registration exists.
    _env: Arc<ort::environment::Environment>,
}
/// Process-wide map from CUDA device id to the pool registered with ORT for it.
static REGISTRY: OnceLock<Mutex<HashMap<i32, Registration>>> = OnceLock::new();
/// Returns the registry, creating an empty one on first use.
fn registry() -> &'static Mutex<HashMap<i32, Registration>> {
    REGISTRY.get_or_init(Default::default)
}
/// Converts an ORT status pointer into a `Result`.
///
/// A null status means success. Otherwise the error message is copied out and the
/// status is released, so the caller never has to free it.
fn status_to_result(status: ort_sys::OrtStatusPtr) -> Result<(), String> {
    let raw = status.0;
    if raw.is_null() {
        return Ok(());
    }
    let api = ort::api();
    // SAFETY: `raw` is a non-null status produced by ORT. The message is owned by the
    // status, so it is copied into a `String` before the status is released below.
    let message = unsafe { CStr::from_ptr((api.GetErrorMessage)(raw)) }
        .to_string_lossy()
        .into_owned();
    // SAFETY: `raw` is released exactly once and not used afterwards.
    unsafe { (api.ReleaseStatus)(raw) };
    Err(message)
}
/// Registers `handle`'s GPU block pool with the global ORT environment as the shared
/// allocator for CUDA device `device_id`.
///
/// Sessions opt into environment allocators through the ORT session config key
/// [`USE_ENV_ALLOCATORS_KEY`]. Registering the same pool again for the same device is
/// a no-op.
///
/// On success the allocator wrapper is intentionally leaked: ORT keeps using the
/// pointer, and nothing in this module unregisters it.
///
/// # Errors
/// Returns a descriptive message when:
/// - the pool's block size is zero or not a multiple of `CUDA_ALLOC_ALIGN`;
/// - a different pool is already registered for `device_id`;
/// - the ORT environment or memory info cannot be obtained;
/// - `RegisterAllocator` fails.
///
/// # Panics
/// Panics if the registry mutex is poisoned.
pub fn register_pool_allocator(device_id: i32, handle: &GpuPoolHandle) -> Result<(), String> {
    let pool = handle.pool.clone();
    let bytes_per_block = pool.bytes_per_block();
    // Every allocation starts on a block boundary, so the block size must preserve
    // the alignment ORT/CUDA expect.
    if bytes_per_block == 0 || bytes_per_block % CUDA_ALLOC_ALIGN != 0 {
        return Err(format!(
            "pool block size ({bytes_per_block} bytes) is not a multiple of the required \
             {CUDA_ALLOC_ALIGN}-byte CUDA alignment; cannot serve ORT allocations"
        ));
    }
    // Held for the rest of the function so concurrent registrations are serialized.
    let mut reg = registry().lock().unwrap();
    if let Some(existing) = reg.get(&device_id) {
        if Arc::ptr_eq(&existing.handle.pool, &pool) {
            return Ok(());
        }
        return Err(format!(
            "a different GpuBlockPool is already registered with ORT for device {device_id}"
        ));
    }
    let env = ort::environment::get_environment()
        .map_err(|e| format!("failed to obtain ORT environment: {e}"))?;
    let mem_info = MemoryInfo::new(
        AllocationDevice::CUDA,
        device_id,
        AllocatorType::Device,
        MemoryType::Default,
    )
    .map_err(|e| format!("failed to create ORT memory info: {e}"))?;
    let allocator = Box::new(PoolOrtAllocator {
        ort: ort_sys::OrtAllocator {
            version: ort_sys::ORT_API_VERSION,
            Alloc: Some(pool_alloc),
            Free: Some(pool_free),
            Info: Some(pool_info),
            Reserve: Some(pool_reserve),
        },
        state: AllocState {
            pool: pool.clone(),
            mem_info,
            bytes_per_block,
            device_id,
            live: Mutex::new(HashMap::new()),
        },
    });
    // Use a raw pointer to the WHOLE struct, not `&mut allocator.ort`. The callbacks
    // cast the `OrtAllocator*` back to `PoolOrtAllocator*` and read `state`, and a
    // field-only reference would not carry provenance over that memory.
    let raw: *mut PoolOrtAllocator = Box::into_raw(allocator);
    // SAFETY:
    // - `env` keeps the environment alive for the duration of the call.
    // - `raw` points to a fully initialised `#[repr(C)]` struct whose first field is
    //   the `OrtAllocator`, so the cast yields a valid `OrtAllocator*`.
    let status = unsafe {
        (ort::api().RegisterAllocator)(
            env.ptr().cast_mut(),
            raw.cast::<ort_sys::OrtAllocator>(),
        )
    };
    if let Err(e) = status_to_result(status) {
        // SAFETY: registration failed, so ORT holds no reference to `raw`. It came from
        // `Box::into_raw` above and is reclaimed exactly once here, instead of being
        // leaked on every failed attempt.
        drop(unsafe { Box::from_raw(raw) });
        return Err(format!("ORT RegisterAllocator failed: {e}"));
    }
    log::info!(
        "Registered shared GPU pool with ORT for device {}: {} blocks × {} KiB ({} MiB)",
        device_id,
        pool.total_blocks(),
        bytes_per_block / 1024,
        pool.capacity_bytes() / (1024 * 1024),
    );
    reg.insert(
        device_id,
        Registration {
            handle: handle.clone(),
            _env: env,
        },
    );
    Ok(())
}
/// Reports whether a pool has been registered with ORT for `device_id`.
///
/// # Panics
/// Panics if the registry mutex is poisoned.
pub fn is_registered(device_id: i32) -> bool {
    let registrations = registry().lock().unwrap();
    registrations.contains_key(&device_id)
}
/// Returns a clone of the pool handle registered with ORT for `device_id`, if any.
///
/// # Panics
/// Panics if the registry mutex is poisoned.
pub fn registered_pool_handle(device_id: i32) -> Option<GpuPoolHandle> {
    let registrations = registry().lock().unwrap();
    let registration = registrations.get(&device_id)?;
    Some(registration.handle.clone())
}