use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use cudarc::driver::{CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceSlice, SyncOnDrop};
use xlog_core::{MemoryBudget, Result, Schema, XlogError};
use crate::arrow_device::ArrowDeviceImport;
use crate::cuda_compat::{AsKernelParam, DeviceParamStorage, IntoKernelParamStorage};
use crate::dlpack::DlpackManagedTensor;
use crate::CudaDevice;
/// Hands out device allocations while keeping a running byte count that is
/// checked against a fixed [`MemoryBudget`].
pub struct GpuMemoryManager {
/// Device on which allocations are performed.
device: Arc<CudaDevice>,
/// Budget whose `device_bytes` limit `alloc` and `check_budget` compare against.
budget: MemoryBudget,
/// Number of bytes currently counted as allocated; raised by `alloc`,
/// lowered by `record_free`, zeroed by `reset_tracking`.
allocated: AtomicU64,
}
/// A [`CudaSlice`] whose size is accounted for in a [`GpuMemoryManager`].
///
/// Dropping the wrapper reports `bytes` back to the manager via `record_free`.
pub struct TrackedCudaSlice<T: cudarc::driver::DeviceRepr> {
/// Size of the allocation in bytes, as counted by the manager.
bytes: u64,
/// Manager whose counter this allocation is charged to.
manager: Arc<GpuMemoryManager>,
/// The wrapped device slice.
inner: CudaSlice<T>,
/// Raw device pointer of `inner`, captured when the wrapper is built so a
/// reference to it can be handed out without going through a stream.
raw_ptr: cudarc::driver::sys::CUdeviceptr,
}
/// Shared access to the wrapped [`CudaSlice`].
impl<T> Deref for TrackedCudaSlice<T>
where
    T: cudarc::driver::DeviceRepr,
{
    type Target = CudaSlice<T>;

    /// Borrows the underlying device slice.
    fn deref(&self) -> &CudaSlice<T> {
        let Self { inner, .. } = self;
        inner
    }
}
impl<T: cudarc::driver::DeviceRepr> DerefMut for TrackedCudaSlice<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
/// Length and stream queries are forwarded to the wrapped slice.
impl<T> DeviceSlice<T> for TrackedCudaSlice<T>
where
    T: cudarc::driver::DeviceRepr,
{
    /// Number of `T` elements in the wrapped slice.
    fn len(&self) -> usize {
        let wrapped = &self.inner;
        wrapped.len()
    }

    /// Stream the wrapped slice is associated with.
    fn stream(&self) -> &Arc<CudaStream> {
        let wrapped = &self.inner;
        wrapped.stream()
    }
}
impl<T: cudarc::driver::DeviceRepr> DevicePtr<T> for TrackedCudaSlice<T> {
fn device_ptr<'a>(
&'a self,
stream: &'a CudaStream,
) -> (cudarc::driver::sys::CUdeviceptr, SyncOnDrop<'a>) {
DevicePtr::device_ptr(&self.inner, stream)
}
}
impl<T: cudarc::driver::DeviceRepr> DevicePtrMut<T> for TrackedCudaSlice<T> {
fn device_ptr_mut<'a>(
&'a mut self,
stream: &'a CudaStream,
) -> (cudarc::driver::sys::CUdeviceptr, SyncOnDrop<'a>) {
DevicePtrMut::device_ptr_mut(&mut self.inner, stream)
}
}
impl<T: cudarc::driver::DeviceRepr> TrackedCudaSlice<T> {
pub fn device_ptr(&self) -> &cudarc::driver::sys::CUdeviceptr {
&self.raw_ptr
}
pub fn device_ptr_value(&self) -> cudarc::driver::sys::CUdeviceptr {
self.raw_ptr
}
pub fn into_bytes(self) -> TrackedCudaSlice<u8> {
let this = ManuallyDrop::new(self);
let bytes = this.bytes;
let manager = Arc::clone(&this.manager);
let ptr = this.raw_ptr;
let len_bytes: usize = bytes
.try_into()
.expect("TrackedCudaSlice byte size must fit into usize");
let inner = unsafe {
manager
.device
.inner()
.upgrade_device_ptr::<u8>(ptr, len_bytes)
};
TrackedCudaSlice {
bytes,
manager,
inner,
raw_ptr: ptr,
}
}
}
/// Kernel-parameter view of a shared slice reference.
impl<T> AsKernelParam for &TrackedCudaSlice<T>
where
    T: cudarc::driver::DeviceRepr,
{
    /// Returns the address of the slice's cached device pointer, type-erased
    /// to `*mut c_void`.
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        let slot: *const cudarc::driver::sys::CUdeviceptr = (*self).device_ptr();
        slot.cast_mut().cast::<std::ffi::c_void>()
    }
}
/// Kernel-parameter view of an exclusive slice reference.
impl<T> AsKernelParam for &mut TrackedCudaSlice<T>
where
    T: cudarc::driver::DeviceRepr,
{
    /// Returns the address of the slice's cached device pointer, type-erased
    /// to `*mut c_void`.
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        let slot: *const cudarc::driver::sys::CUdeviceptr = self.device_ptr();
        slot.cast_mut().cast::<std::ffi::c_void>()
    }
}
/// Kernel-parameter storage for a shared slice reference.
impl<'a, T> IntoKernelParamStorage for &'a TrackedCudaSlice<T>
where
    T: cudarc::driver::DeviceRepr,
{
    type Storage = DeviceParamStorage<'a>;

    /// Fetches the device pointer on the slice's own stream and keeps the
    /// accompanying guard inside the returned storage.
    fn into_kernel_param_storage(self) -> DeviceParamStorage<'a> {
        let wrapped = &self.inner;
        let (device_address, guard) = DevicePtr::device_ptr(wrapped, wrapped.stream());
        DeviceParamStorage::synced(device_address, guard)
    }
}
impl<'a, T: cudarc::driver::DeviceRepr> IntoKernelParamStorage for &'a mut TrackedCudaSlice<T> {
type Storage = DeviceParamStorage<'static>;
fn into_kernel_param_storage(self) -> Self::Storage {
let stream = self.inner.stream().clone();
let (ptr, sync) = DevicePtrMut::device_ptr_mut(&mut self.inner, &stream);
std::mem::forget(sync);
DeviceParamStorage::unsynced(ptr)
}
}
/// Returns the allocation's bytes to the manager's counter on drop.
impl<T> Drop for TrackedCudaSlice<T>
where
    T: cudarc::driver::DeviceRepr,
{
    /// Reports `bytes` as freed to the owning manager.
    fn drop(&mut self) {
        let released = self.bytes;
        self.manager.record_free(released);
    }
}
impl GpuMemoryManager {
pub fn new(device: Arc<CudaDevice>, budget: MemoryBudget) -> Self {
Self {
device,
budget,
allocated: AtomicU64::new(0),
}
}
pub fn alloc<T: cudarc::driver::DeviceRepr>(
self: &Arc<Self>,
len: usize,
) -> Result<TrackedCudaSlice<T>> {
let bytes = (len as u64)
.checked_mul(std::mem::size_of::<T>() as u64)
.ok_or_else(|| XlogError::Kernel("Allocation size overflow".to_string()))?;
loop {
let current = self.allocated.load(Ordering::SeqCst);
let new_val = current.saturating_add(bytes);
if new_val > self.budget.device_bytes {
return Err(XlogError::ResourceExhausted {
context: "GPU memory allocation".to_string(),
estimated_bytes: bytes,
budget_bytes: self.budget.device_bytes,
});
}
if self
.allocated
.compare_exchange(current, new_val, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
{
break;
}
}
let slice = unsafe {
self.device.inner().alloc::<T>(len).map_err(|e| {
self.allocated.fetch_sub(bytes, Ordering::SeqCst);
XlogError::Kernel(format!("GPU allocation failed: {}", e))
})?
};
let (raw_ptr, sync) = DevicePtr::device_ptr(&slice, slice.stream());
std::mem::forget(sync);
Ok(TrackedCudaSlice {
bytes,
manager: Arc::clone(self),
inner: slice,
raw_ptr,
})
}
pub fn check_budget(&self, bytes: u64) -> Result<()> {
let current = self.allocated.load(Ordering::SeqCst);
let proposed = current.saturating_add(bytes);
if proposed > self.budget.device_bytes {
return Err(XlogError::ResourceExhausted {
context: "GPU memory allocation".to_string(),
estimated_bytes: bytes,
budget_bytes: self.budget.device_bytes,
});
}
Ok(())
}
pub fn allocated_bytes(&self) -> u64 {
self.allocated.load(Ordering::SeqCst)
}
pub fn budget(&self) -> &MemoryBudget {
&self.budget
}
pub fn device(&self) -> &Arc<CudaDevice> {
&self.device
}
pub fn record_free(&self, bytes: u64) {
self.allocated.fetch_sub(bytes, Ordering::SeqCst);
}
pub fn remaining_bytes(&self) -> u64 {
let allocated = self.allocated.load(Ordering::SeqCst);
self.budget.device_bytes.saturating_sub(allocated)
}
pub fn reset_tracking(&self) {
self.allocated.store(0, Ordering::SeqCst);
}
}
/// One column of device bytes, distinguished by where the memory comes from.
pub enum CudaColumn {
/// Bytes held in a budget-tracked slice.
Owned(TrackedCudaSlice<u8>),
/// Raw pointer accompanied by a DLPack managed tensor.
Dlpack(DlpackColumn),
/// Raw pointer accompanied by an Arrow device import.
ArrowDevice(ArrowDeviceColumn),
}
/// Column backed by a raw device pointer that arrived with a DLPack tensor.
pub struct DlpackColumn {
/// Raw device pointer of the column data.
ptr: cudarc::driver::sys::CUdeviceptr,
/// Length of the column in bytes.
len_bytes: usize,
/// Stream reported for this column.
stream: Arc<CudaStream>,
/// Never read in this file; presumably held so the tensor (and the memory
/// behind `ptr`) outlives the column — confirm against `DlpackManagedTensor`.
_tensor: DlpackManagedTensor,
}
/// Column backed by a raw device pointer that arrived with an Arrow device import.
pub struct ArrowDeviceColumn {
/// Raw device pointer of the column data.
ptr: cudarc::driver::sys::CUdeviceptr,
/// Length of the column in bytes.
len_bytes: usize,
/// Stream reported for this column.
stream: Arc<CudaStream>,
/// Never read in this file; presumably held so the import (and the memory
/// behind `ptr`) outlives the column — confirm against `ArrowDeviceImport`.
_import: Arc<ArrowDeviceImport>,
}
impl CudaColumn {
    /// Wraps a budget-tracked byte slice as a column.
    pub fn owned(slice: TrackedCudaSlice<u8>) -> Self {
        CudaColumn::Owned(slice)
    }

    /// Builds a column from a raw pointer, its byte length, a stream and the
    /// DLPack tensor that accompanies it.
    pub fn dlpack(
        ptr: cudarc::driver::sys::CUdeviceptr,
        len_bytes: usize,
        stream: Arc<CudaStream>,
        tensor: DlpackManagedTensor,
    ) -> Self {
        let column = DlpackColumn {
            ptr,
            len_bytes,
            stream,
            _tensor: tensor,
        };
        CudaColumn::Dlpack(column)
    }

    /// Builds a column from a raw pointer, its byte length, a stream and the
    /// Arrow device import that accompanies it.
    pub fn arrow_device(
        ptr: cudarc::driver::sys::CUdeviceptr,
        len_bytes: usize,
        stream: Arc<CudaStream>,
        import: Arc<ArrowDeviceImport>,
    ) -> Self {
        let column = ArrowDeviceColumn {
            ptr,
            len_bytes,
            stream,
            _import: import,
        };
        CudaColumn::ArrowDevice(column)
    }

    /// Stream associated with the column, whichever variant it is.
    pub fn stream(&self) -> &Arc<CudaStream> {
        match self {
            Self::Owned(slice) => slice.stream(),
            Self::Dlpack(DlpackColumn { stream, .. })
            | Self::ArrowDevice(ArrowDeviceColumn { stream, .. }) => stream,
        }
    }

    /// Reference to the column's raw device pointer, whichever variant it is.
    pub fn device_ptr(&self) -> &cudarc::driver::sys::CUdeviceptr {
        match self {
            Self::Owned(slice) => slice.device_ptr(),
            Self::Dlpack(DlpackColumn { ptr, .. })
            | Self::ArrowDevice(ArrowDeviceColumn { ptr, .. }) => ptr,
        }
    }
}
impl From<TrackedCudaSlice<u8>> for CudaColumn {
fn from(value: TrackedCudaSlice<u8>) -> Self {
CudaColumn::Owned(value)
}
}
impl DeviceSlice<u8> for CudaColumn {
fn len(&self) -> usize {
match self {
CudaColumn::Owned(slice) => slice.len(),
CudaColumn::Dlpack(col) => col.len_bytes,
CudaColumn::ArrowDevice(col) => col.len_bytes,
}
}
fn stream(&self) -> &Arc<CudaStream> {
self.stream()
}
}
impl DevicePtr<u8> for CudaColumn {
fn device_ptr<'a>(
&'a self,
stream: &'a CudaStream,
) -> (cudarc::driver::sys::CUdeviceptr, SyncOnDrop<'a>) {
match self {
CudaColumn::Owned(slice) => DevicePtr::device_ptr(slice, stream),
CudaColumn::Dlpack(col) => (col.ptr, SyncOnDrop::Sync(None)),
CudaColumn::ArrowDevice(col) => (col.ptr, SyncOnDrop::Sync(None)),
}
}
}
impl DevicePtrMut<u8> for CudaColumn {
fn device_ptr_mut<'a>(
&'a mut self,
stream: &'a CudaStream,
) -> (cudarc::driver::sys::CUdeviceptr, SyncOnDrop<'a>) {
match self {
CudaColumn::Owned(slice) => DevicePtrMut::device_ptr_mut(slice, stream),
CudaColumn::Dlpack(col) => (col.ptr, SyncOnDrop::Sync(None)),
CudaColumn::ArrowDevice(col) => (col.ptr, SyncOnDrop::Sync(None)),
}
}
}
/// Kernel-parameter view of a shared column reference.
impl AsKernelParam for &CudaColumn {
    /// Returns the address of the column's stored device pointer, type-erased
    /// to `*mut c_void`.
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        let slot: *const cudarc::driver::sys::CUdeviceptr = self.device_ptr();
        slot.cast_mut().cast::<std::ffi::c_void>()
    }
}
/// Kernel-parameter view of an exclusive column reference.
impl AsKernelParam for &mut CudaColumn {
    /// Returns the address of the column's stored device pointer, type-erased
    /// to `*mut c_void`.
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        let slot: *const cudarc::driver::sys::CUdeviceptr = self.device_ptr();
        slot.cast_mut().cast::<std::ffi::c_void>()
    }
}
impl<'a> IntoKernelParamStorage for &'a CudaColumn {
type Storage = DeviceParamStorage<'a>;
fn into_kernel_param_storage(self) -> Self::Storage {
match self {
CudaColumn::Owned(slice) => slice.into_kernel_param_storage(),
CudaColumn::Dlpack(col) => DeviceParamStorage::unsynced(col.ptr),
CudaColumn::ArrowDevice(col) => DeviceParamStorage::unsynced(col.ptr),
}
}
}
impl<'a> IntoKernelParamStorage for &'a mut CudaColumn {
type Storage = DeviceParamStorage<'a>;
fn into_kernel_param_storage(self) -> Self::Storage {
match self {
CudaColumn::Owned(slice) => slice.into_kernel_param_storage(),
CudaColumn::Dlpack(col) => DeviceParamStorage::unsynced(col.ptr),
CudaColumn::ArrowDevice(col) => DeviceParamStorage::unsynced(col.ptr),
}
}
}
/// A set of device columns described by a [`Schema`], plus row bookkeeping.
pub struct CudaBuffer {
/// Column data; the constructors assert there is one entry per schema field.
pub columns: Vec<CudaColumn>,
/// Row capacity; this is the value `num_rows` returns and `is_empty` tests.
pub row_cap: u64,
/// Device-side `u32` slice returned by `num_rows_device`
/// (presumably holds the logical row count — confirm with the kernels that write it).
pub d_num_rows: TrackedCudaSlice<u32>,
/// Schema describing the columns.
pub schema: Schema,
/// Host-side cached row count; `u32::MAX` means "not set".
cached_row_count: AtomicU32,
}
impl CudaBuffer {
    /// Builds a buffer with no cached host-side row count.
    ///
    /// # Panics
    ///
    /// Panics if `columns.len()` differs from `schema.arity()`.
    pub fn from_columns(
        columns: Vec<CudaColumn>,
        row_cap: u64,
        d_num_rows: TrackedCudaSlice<u32>,
        schema: Schema,
    ) -> Self {
        Self::assemble(columns, row_cap, d_num_rows, schema, u32::MAX)
    }

    /// Builds a buffer whose cached host-side row count starts at
    /// `host_row_count`.
    ///
    /// `u32::MAX` is the "not set" marker, so passing it behaves like
    /// [`CudaBuffer::from_columns`].
    ///
    /// # Panics
    ///
    /// Panics if `columns.len()` differs from `schema.arity()`.
    pub fn from_columns_with_host_count(
        columns: Vec<CudaColumn>,
        row_cap: u64,
        d_num_rows: TrackedCudaSlice<u32>,
        schema: Schema,
        host_row_count: u32,
    ) -> Self {
        Self::assemble(columns, row_cap, d_num_rows, schema, host_row_count)
    }

    /// Shared constructor body: checks arity and stores the initial cache value.
    fn assemble(
        columns: Vec<CudaColumn>,
        row_cap: u64,
        d_num_rows: TrackedCudaSlice<u32>,
        schema: Schema,
        initial_cached_count: u32,
    ) -> Self {
        assert_eq!(
            columns.len(),
            schema.arity(),
            "Number of columns ({}) must match schema arity ({})",
            columns.len(),
            schema.arity()
        );
        Self {
            columns,
            row_cap,
            d_num_rows,
            schema,
            cached_row_count: AtomicU32::new(initial_cached_count),
        }
    }

    /// Cached host-side row count, or `None` while it is unset (`u32::MAX`).
    pub fn cached_row_count(&self) -> Option<u32> {
        let v = self.cached_row_count.load(Ordering::Relaxed);
        (v != u32::MAX).then_some(v)
    }

    /// Stores `count` as the cached row count unless one is already set.
    pub fn set_cached_row_count_if_unset(&self, count: u32) {
        // A failed exchange means a value is already cached; keep it.
        let _ = self.cached_row_count.compare_exchange(
            u32::MAX,
            count,
            Ordering::Relaxed,
            Ordering::Relaxed,
        );
    }

    /// Returns the row capacity (`row_cap`), not the device-side count.
    pub fn num_rows(&self) -> u64 {
        self.row_cap
    }

    /// Device-side row-count slice.
    pub fn num_rows_device(&self) -> &TrackedCudaSlice<u32> {
        &self.d_num_rows
    }

    /// `true` when the row capacity is zero.
    pub fn is_empty(&self) -> bool {
        self.row_cap == 0
    }

    /// Schema describing the columns.
    pub fn schema(&self) -> &Schema {
        &self.schema
    }

    /// Number of fields in the schema.
    pub fn arity(&self) -> usize {
        self.schema.arity()
    }

    /// Row capacity times the schema's row size, saturating at `u64::MAX`.
    ///
    /// `row_cap` is a public field, so the product can overflow; saturating
    /// keeps this estimate from panicking (debug) or wrapping to a small
    /// number (release).
    pub fn estimated_bytes(&self) -> u64 {
        let row_bytes = self.schema.row_size_bytes() as u64;
        self.row_cap.saturating_mul(row_bytes)
    }

    /// Column at `index`, or `None` if out of range.
    pub fn column(&self, index: usize) -> Option<&CudaColumn> {
        self.columns.get(index)
    }
}
/// Checks that `logical_rows` does not exceed `row_cap` and returns it.
///
/// # Errors
///
/// Returns `XlogError::Kernel` when `row_cap` cannot be represented as
/// `usize`, or when `logical_rows` is larger than `row_cap`.
pub fn validate_logical_row_count(row_cap: u64, logical_rows: usize) -> Result<usize> {
    let capacity = match usize::try_from(row_cap) {
        Ok(capacity) => capacity,
        Err(_) => {
            return Err(XlogError::Kernel(format!(
                "Row capacity {} exceeds usize::MAX",
                row_cap
            )));
        }
    };

    if logical_rows <= capacity {
        Ok(logical_rows)
    } else {
        Err(XlogError::Kernel(format!(
            "Logical row count {} exceeds row capacity {}",
            logical_rows, row_cap
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use xlog_core::ScalarType;

    /// Opens CUDA device 0; logs a skip notice and yields `None` on failure.
    fn try_device() -> Option<Arc<CudaDevice>> {
        CudaDevice::new(0)
            .map(Arc::new)
            .map_err(|e| eprintln!("Skipping test: CUDA runtime unavailable: {}", e))
            .ok()
    }

    /// Wraps `device` in a shared manager enforcing `budget`.
    fn manager_for(device: Arc<CudaDevice>, budget: MemoryBudget) -> Arc<GpuMemoryManager> {
        Arc::new(GpuMemoryManager::new(device, budget))
    }

    /// A buffer with no columns and zero capacity reports itself as empty.
    #[test]
    fn test_cuda_buffer_empty() {
        let Some(device) = try_device() else {
            return;
        };
        let manager = manager_for(device, MemoryBudget::with_limit(1024 * 1024));

        let mut d_num_rows = manager.alloc::<u32>(1).unwrap();
        manager
            .device()
            .inner()
            .htod_sync_copy_into(&[0u32], &mut d_num_rows)
            .unwrap();

        let buffer = CudaBuffer::from_columns(Vec::new(), 0, d_num_rows, Schema::new(vec![]));

        assert!(buffer.is_empty());
        assert_eq!(buffer.num_rows(), 0);
        assert_eq!(buffer.arity(), 0);
        assert_eq!(buffer.estimated_bytes(), 0);
    }

    /// Schema, arity and size estimate are reported from the constructor inputs.
    #[test]
    fn test_cuda_buffer_schema() {
        let schema = Schema::new(vec![
            ("a".to_string(), ScalarType::U32),
            ("b".to_string(), ScalarType::U64),
        ]);
        let Some(device) = try_device() else {
            return;
        };
        let manager = manager_for(device, MemoryBudget::with_limit(1024 * 1024));

        let mut d_num_rows = manager.alloc::<u32>(1).unwrap();
        manager
            .device()
            .inner()
            .htod_sync_copy_into(&[100u32], &mut d_num_rows)
            .unwrap();

        let col_a = CudaColumn::owned(manager.alloc::<u8>(100 * 4).unwrap());
        let col_b = CudaColumn::owned(manager.alloc::<u8>(100 * 8).unwrap());
        let buffer = CudaBuffer::from_columns(vec![col_a, col_b], 100, d_num_rows, schema.clone());

        assert_eq!(buffer.num_rows(), 100);
        assert_eq!(buffer.arity(), 2);
        assert_eq!(buffer.estimated_bytes(), 1200);
        assert_eq!(buffer.schema(), &schema);
    }

    /// A fresh manager has nothing allocated and the full budget remaining.
    #[test]
    fn test_memory_manager_creation() {
        let Some(device) = try_device() else {
            return;
        };
        let manager = manager_for(device, MemoryBudget::with_limit(1024 * 1024));

        assert_eq!(manager.allocated_bytes(), 0);
        assert_eq!(manager.budget().device_bytes, 1024 * 1024);
        assert_eq!(manager.remaining_bytes(), 1024 * 1024);
    }

    /// An allocation is charged against the budget by its byte size.
    #[test]
    fn test_memory_manager_alloc() {
        let Some(device) = try_device() else {
            return;
        };
        let manager = manager_for(device, MemoryBudget::with_limit(1024 * 1024));

        let _slice = manager
            .alloc::<u32>(256)
            .expect("Allocation should succeed");

        assert_eq!(manager.allocated_bytes(), 1024);
        assert_eq!(manager.remaining_bytes(), 1024 * 1024 - 1024);
    }

    /// Asking for more than the budget yields `ResourceExhausted`.
    #[test]
    fn test_memory_manager_budget_exceeded() {
        let Some(device) = try_device() else {
            return;
        };
        let manager = manager_for(device, MemoryBudget::with_limit(1024));

        let result = manager.alloc::<u32>(512);
        assert!(result.is_err());

        match result {
            Err(XlogError::ResourceExhausted {
                estimated_bytes,
                budget_bytes,
                ..
            }) => {
                assert_eq!(estimated_bytes, 2048);
                assert_eq!(budget_bytes, 1024);
            }
            _ => panic!("Expected ResourceExhausted error"),
        }
    }

    /// `check_budget` accepts sizes within the limit and rejects larger ones.
    #[test]
    fn test_memory_manager_check_budget() {
        let Some(device) = try_device() else {
            return;
        };
        let manager = manager_for(device, MemoryBudget::with_limit(1000));

        assert!(manager.check_budget(500).is_ok());
        assert!(manager.check_budget(1001).is_err());
    }

    /// Successive allocations accumulate; a rejected one leaves the total alone.
    #[test]
    fn test_memory_manager_multiple_allocs() {
        let Some(device) = try_device() else {
            return;
        };
        let manager = manager_for(device, MemoryBudget::with_limit(4096));

        let _slice1 = manager
            .alloc::<u32>(256)
            .expect("First allocation should succeed");
        assert_eq!(manager.allocated_bytes(), 1024);

        let _slice2 = manager
            .alloc::<u32>(256)
            .expect("Second allocation should succeed");
        assert_eq!(manager.allocated_bytes(), 2048);

        let result = manager.alloc::<u32>(1024);
        assert!(result.is_err());
        assert_eq!(manager.allocated_bytes(), 2048);
    }

    /// Dropping a slice returns its bytes to the budget.
    #[test]
    fn test_memory_manager_record_free() {
        let Some(device) = try_device() else {
            return;
        };
        let manager = manager_for(device, MemoryBudget::with_limit(4096));

        let slice = manager
            .alloc::<u32>(256)
            .expect("Allocation should succeed");
        assert_eq!(manager.allocated_bytes(), 1024);

        drop(slice);
        assert_eq!(manager.allocated_bytes(), 0);
        assert_eq!(manager.remaining_bytes(), 4096);
    }

    /// Columns passed to the constructor are retrievable by index.
    #[test]
    fn test_cuda_buffer_from_columns() {
        let Some(device) = try_device() else {
            return;
        };
        let manager = manager_for(device, MemoryBudget::with_limit(1024 * 1024));
        let schema = Schema::new(vec![
            ("col1".to_string(), ScalarType::U32),
            ("col2".to_string(), ScalarType::U32),
        ]);

        let col1 = manager.alloc::<u8>(400).expect("Alloc col1");
        let col2 = manager.alloc::<u8>(400).expect("Alloc col2");
        let mut d_num_rows = manager.alloc::<u32>(1).expect("Alloc row count");
        manager
            .device()
            .inner()
            .htod_sync_copy_into(&[100u32], &mut d_num_rows)
            .expect("Upload row count");

        let columns: Vec<CudaColumn> = vec![col1.into(), col2.into()];
        let buffer = CudaBuffer::from_columns(columns, 100, d_num_rows, schema);

        assert_eq!(buffer.num_rows(), 100);
        assert_eq!(buffer.arity(), 2);
        assert!(!buffer.is_empty());
        assert!(buffer.column(0).is_some());
        assert!(buffer.column(1).is_some());
        assert!(buffer.column(2).is_none());
    }

    /// A column count that disagrees with the schema makes the constructor panic.
    #[test]
    fn test_cuda_buffer_from_columns_mismatch() {
        let schema = Schema::new(vec![
            ("col1".to_string(), ScalarType::U32),
            ("col2".to_string(), ScalarType::U32),
        ]);
        let Some(device) = try_device() else {
            return;
        };
        let manager = manager_for(device, MemoryBudget::with_limit(1024 * 1024));

        let mut d_num_rows = manager.alloc::<u32>(1).expect("Alloc row count");
        manager
            .device()
            .inner()
            .htod_sync_copy_into(&[100u32], &mut d_num_rows)
            .expect("Upload row count");

        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            CudaBuffer::from_columns(vec![], 100, d_num_rows, schema);
        }));

        assert!(
            outcome.is_err(),
            "Expected from_columns to panic on schema mismatch"
        );
    }
}