use crate::{Result, memory_error};
use std::sync::Arc;
use std::alloc::{alloc, dealloc, Layout};
use std::ptr::NonNull;
/// An owned, 8-byte-aligned buffer of `size` bytes obtained from Rust's global allocator
/// (`std::alloc`).
///
/// The allocation is created in `UnifiedMemory::new` and released by the `Drop` impl.
#[derive(Debug)]
pub struct UnifiedMemory {
    /// Start of the allocation; never null.
    ptr: NonNull<u8>,
    /// Number of usable bytes; equal to `layout.size()`.
    size: usize,
    /// Layout the buffer was allocated with; required again to deallocate it.
    layout: Layout,
}
impl UnifiedMemory {
    /// Allocates a zero-initialised buffer of `size` bytes with 8-byte alignment.
    ///
    /// # Errors
    ///
    /// Returns a memory error if `size` is zero, if `size` cannot form a valid [`Layout`]
    /// with alignment 8 (its size rounded up to a multiple of 8 would exceed `isize::MAX`),
    /// or if the global allocator returns null.
    pub fn new(size: usize) -> Result<Self> {
        if size == 0 {
            return Err(memory_error!("Cannot allocate zero-sized unified memory"));
        }
        let layout = Layout::from_size_align(size, 8)
            .map_err(|e| memory_error!("Invalid layout: {}", e))?;
        // SAFETY: `layout` has a non-zero size, as checked above.
        let raw = unsafe { alloc(layout) };
        let ptr = NonNull::new(raw)
            .ok_or_else(|| memory_error!("Failed to allocate unified memory"))?;
        // `alloc` returns uninitialised memory. Zero it so that `copy_to_slice` (or a read
        // through `as_ptr`) before the first write never copies uninitialised bytes into a
        // `&mut [u8]`, which would be undefined behaviour.
        // SAFETY: `ptr` is a fresh, exclusively owned allocation of exactly `size` bytes.
        unsafe { std::ptr::write_bytes(ptr.as_ptr(), 0, size) };
        Ok(Self { ptr, size, layout })
    }

    /// Returns a read-only raw pointer to the first byte of the buffer.
    pub fn as_ptr(&self) -> *const u8 {
        self.ptr.as_ptr()
    }

    /// Returns a mutable raw pointer to the first byte of the buffer.
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr.as_ptr()
    }

    /// Returns the buffer length in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Copies `data` into the start of the buffer.
    ///
    /// Bytes past `data.len()` keep their previous contents. An empty `data` is a no-op.
    ///
    /// # Errors
    ///
    /// Returns a memory error if `data` is longer than the buffer.
    pub fn copy_from_slice(&mut self, data: &[u8]) -> Result<()> {
        if data.len() > self.size {
            return Err(memory_error!(
                "Data size {} exceeds buffer size {}",
                data.len(),
                self.size
            ));
        }
        // SAFETY: `data.len() <= self.size`, so the destination range lies inside our
        // allocation. `data` is a separate shared borrow while we hold `&mut self`, so the
        // two ranges do not overlap.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), self.ptr.as_ptr(), data.len());
        }
        Ok(())
    }

    /// Fills `data` with the first `data.len()` bytes of the buffer.
    ///
    /// # Errors
    ///
    /// Returns a memory error if `data` is longer than the buffer.
    pub fn copy_to_slice(&self, data: &mut [u8]) -> Result<()> {
        if data.len() > self.size {
            return Err(memory_error!(
                "Destination size {} exceeds buffer size {}",
                data.len(),
                self.size
            ));
        }
        // SAFETY: `data.len() <= self.size`, so the source range lies inside our allocation,
        // which `new` fully initialised. `data` is an exclusive borrow distinct from this
        // buffer, so the ranges do not overlap.
        unsafe {
            std::ptr::copy_nonoverlapping(self.ptr.as_ptr(), data.as_mut_ptr(), data.len());
        }
        Ok(())
    }
}
impl Drop for UnifiedMemory {
    /// Hands the buffer back to the global allocator.
    fn drop(&mut self) {
        let (raw, layout) = (self.ptr.as_ptr(), self.layout);
        // SAFETY: `raw` was returned by `alloc(layout)` in `UnifiedMemory::new`, with this
        // exact `layout`. It is freed only here, and `drop` runs at most once.
        unsafe { dealloc(raw, layout) };
    }
}
// SAFETY: `UnifiedMemory` exclusively owns its allocation (created in `new`, freed in
// `drop`). Its fields are a pointer, a length and a `Layout`, none of which is tied to a
// thread, so moving it to another thread is sound.
unsafe impl Send for UnifiedMemory {}
// SAFETY: the `&self` methods in this file (`as_ptr`, `size`, `copy_to_slice`) only read the
// buffer. Every write path (`as_mut_ptr`, `copy_from_slice`) needs `&mut self`, so shared
// references may be used from several threads concurrently. Writes that callers make
// through the raw pointer from `as_ptr` are their own `unsafe` responsibility.
unsafe impl Sync for UnifiedMemory {}
/// Reference-counted handle to a [`UnifiedMemory`] buffer, for sharing one allocation.
///
/// The mutating methods of `UnifiedMemory` take `&mut self`, so a shared handle only gives
/// read access unless it is uniquely owned (e.g. via `Arc::get_mut`).
pub type SharedUnifiedMemory = Arc<UnifiedMemory>;
/// Allocates a `size`-byte [`UnifiedMemory`] buffer and wraps it in an [`Arc`] so it can be
/// shared.
///
/// # Errors
///
/// Propagates any error from `UnifiedMemory::new`, such as a zero `size` or an
/// allocation failure.
pub fn allocate_unified(size: usize) -> Result<SharedUnifiedMemory> {
    UnifiedMemory::new(size).map(Arc::new)
}
/// A [`UnifiedMemory`] buffer, together with the backend capability flag that was
/// sampled when the buffer was created.
pub struct ManagedMemory {
    /// Set by `ManagedMemory::new` from the backend's `supports_unified_memory` capability.
    backend_registered: bool,
    /// The owned buffer that every copy operation delegates to.
    inner: UnifiedMemory,
}
impl ManagedMemory {
    /// Allocates a `size`-byte [`UnifiedMemory`] buffer and records whether the current
    /// backend advertises unified-memory support.
    ///
    /// The backend is queried only after the allocation succeeds.
    ///
    /// # Errors
    ///
    /// Propagates any error from `UnifiedMemory::new`, such as a zero `size`.
    pub fn new(size: usize) -> Result<Self> {
        let inner = UnifiedMemory::new(size)?;
        let registered = Self::try_register_with_backend(inner.as_ptr(), size);
        Ok(Self {
            backend_registered: registered,
            inner,
        })
    }

    /// Returns the backend capability flag captured in `new`.
    ///
    /// NOTE(review): despite the name, no registration call is visible in this file. The
    /// flag is `true` exactly when the backend reported `supports_unified_memory`.
    pub fn is_backend_registered(&self) -> bool {
        self.backend_registered
    }

    /// Borrows the underlying buffer.
    pub fn as_unified(&self) -> &UnifiedMemory {
        &self.inner
    }

    /// Mutably borrows the underlying buffer.
    pub fn as_unified_mut(&mut self) -> &mut UnifiedMemory {
        &mut self.inner
    }

    /// Returns the buffer length in bytes.
    pub fn size(&self) -> usize {
        self.as_unified().size()
    }

    /// Copies `data` into the start of the buffer; see `UnifiedMemory::copy_from_slice`.
    pub fn copy_from_slice(&mut self, data: &[u8]) -> Result<()> {
        self.as_unified_mut().copy_from_slice(data)
    }

    /// Fills `data` from the start of the buffer; see `UnifiedMemory::copy_to_slice`.
    pub fn copy_to_slice(&self, data: &mut [u8]) -> Result<()> {
        self.as_unified().copy_to_slice(data)
    }

    /// Prefetch hint towards the device. It is currently a no-op and always returns `Ok(())`.
    pub fn prefetch_to_device(&self) -> Result<()> {
        Ok(())
    }

    /// Prefetch hint towards the host. It is currently a no-op and always returns `Ok(())`.
    pub fn prefetch_to_host(&self) -> Result<()> {
        Ok(())
    }

    /// Queries the global backend for unified-memory support.
    ///
    /// Only a capability query happens here; nothing is registered. The pointer and length
    /// are ignored, presumably kept for a future registration call (TODO confirm).
    fn try_register_with_backend(_ptr: *const u8, _size: usize) -> bool {
        crate::backend::get_backend()
            .capabilities()
            .supports_unified_memory
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unified_memory_allocation() {
        let mem = UnifiedMemory::new(1024).unwrap();
        assert_eq!(mem.size(), 1024);
    }

    #[test]
    fn test_unified_memory_copy() {
        let mut mem = UnifiedMemory::new(256).unwrap();
        let data = vec![42u8; 256];
        mem.copy_from_slice(&data).unwrap();
        let mut output = vec![0u8; 256];
        mem.copy_to_slice(&mut output).unwrap();
        assert_eq!(data, output);
    }

    #[test]
    fn test_zero_size_allocation() {
        let result = UnifiedMemory::new(0);
        assert!(result.is_err());
    }

    /// A write longer than the buffer must be rejected rather than overflow it.
    #[test]
    fn test_copy_from_slice_rejects_oversized_input() {
        let mut mem = UnifiedMemory::new(16).unwrap();
        assert!(mem.copy_from_slice(&[0u8; 17]).is_err());
    }

    /// A read into a destination longer than the buffer must be rejected.
    #[test]
    fn test_copy_to_slice_rejects_oversized_destination() {
        let mem = UnifiedMemory::new(16).unwrap();
        let mut out = [0u8; 17];
        assert!(mem.copy_to_slice(&mut out).is_err());
    }

    /// Copies shorter than the buffer operate on its leading bytes.
    #[test]
    fn test_partial_copy_round_trip() {
        let mut mem = UnifiedMemory::new(64).unwrap();
        mem.copy_from_slice(&[1u8, 2, 3]).unwrap();
        let mut out = [0u8; 3];
        mem.copy_to_slice(&mut out).unwrap();
        assert_eq!(out, [1u8, 2, 3]);
    }

    /// `allocate_unified` returns a shareable handle and propagates allocation errors.
    #[test]
    fn test_allocate_unified_shares_buffer() {
        let shared = allocate_unified(32).unwrap();
        let other = Arc::clone(&shared);
        assert_eq!(other.size(), 32);
        assert_eq!(shared.as_ptr(), other.as_ptr());
        assert!(allocate_unified(0).is_err());
    }

    #[test]
    fn test_managed_memory() {
        let mem = ManagedMemory::new(512).unwrap();
        assert_eq!(mem.size(), 512);
    }

    #[test]
    fn test_managed_memory_copy() {
        let mut mem = ManagedMemory::new(128).unwrap();
        let data = vec![0xAB_u8; 128];
        mem.copy_from_slice(&data).unwrap();
        let mut out = vec![0u8; 128];
        mem.copy_to_slice(&mut out).unwrap();
        assert_eq!(data, out);
    }

    #[test]
    fn test_managed_memory_prefetch() {
        let mem = ManagedMemory::new(64).unwrap();
        assert!(mem.prefetch_to_device().is_ok());
        assert!(mem.prefetch_to_host().is_ok());
    }
}