use cubecl_common::bytes::{AllocationController, AllocationProperty};
use cubecl_core::server::IoError;
use cubecl_runtime::{
memory_management::{ManagedMemoryBinding, MemoryManagement},
storage::{BytesResource, BytesStorage},
};
/// Allocation controller that exposes a region of memory managed by a CPU
/// [`MemoryManagement<BytesStorage>`] as a byte slice through [`AllocationController`].
pub struct CpuAllocController {
    /// Resource resolved from the binding in `init`; its `read`/`write` methods provide
    /// the byte views returned by `memory` and `memory_mut`.
    resource: BytesResource,
    // NOTE(review): never read. Presumably held so the managed allocation is not freed or
    // reused while this controller (and the slices it hands out) is alive — confirm.
    _binding: ManagedMemoryBinding,
}
impl AllocationController for CpuAllocController {
    /// Alignment promised for the exposed bytes: plain byte alignment only.
    fn alloc_align(&self) -> usize {
        std::mem::align_of::<u8>()
    }

    /// Reports this allocation as [`AllocationProperty::Other`].
    fn property(&self) -> AllocationProperty {
        AllocationProperty::Other
    }

    /// Returns a mutable view of the bound bytes, reinterpreted as `MaybeUninit<u8>`.
    ///
    /// # Safety
    /// Callers must uphold the contract of [`AllocationController::memory_mut`]. In
    /// particular, the underlying storage is typed as initialized `u8`, so writing
    /// uninitialized values through this view and later reading them as bytes is undefined
    /// behavior.
    unsafe fn memory_mut(&mut self) -> &mut [std::mem::MaybeUninit<u8>] {
        let bytes = self.resource.write();
        let len = bytes.len();
        let ptr = bytes.as_mut_ptr().cast::<std::mem::MaybeUninit<u8>>();
        // SAFETY: `MaybeUninit<u8>` has the same size and alignment as `u8`, and `ptr`/`len`
        // come from the slice just obtained from the resource, so the new slice covers
        // exactly the same memory with the same lifetime.
        unsafe { std::slice::from_raw_parts_mut(ptr, len) }
    }

    /// Returns a read-only view of the bound bytes, reinterpreted as `MaybeUninit<u8>`.
    fn memory(&self) -> &[std::mem::MaybeUninit<u8>] {
        let bytes = self.resource.read();
        let len = bytes.len();
        let ptr = bytes.as_ptr().cast::<std::mem::MaybeUninit<u8>>();
        // SAFETY: `MaybeUninit<u8>` is layout-compatible with `u8`, and viewing initialized
        // bytes as possibly-uninitialized is always sound for a shared, read-only slice.
        unsafe { std::slice::from_raw_parts(ptr, len) }
    }
}
impl CpuAllocController {
    /// Builds a controller for the memory region described by `binding`.
    ///
    /// The binding's memory handle and offsets are forwarded to
    /// `memory_management.get_resource` to resolve the backing [`BytesResource`]. A clone of
    /// the memory handle is stored inside the controller.
    ///
    /// # Errors
    /// Propagates the [`IoError`] returned by `get_resource` when the resource cannot be
    /// resolved.
    pub fn init(
        binding: cubecl_core::server::Binding,
        memory_management: &mut MemoryManagement<BytesStorage>,
    ) -> Result<Self, IoError> {
        let cubecl_core::server::Binding {
            memory: handle,
            offset_start: start,
            offset_end: end,
            ..
        } = binding;
        // Take our own copy of the handle before the original is moved into `get_resource`.
        let retained = handle.clone();
        let resource = memory_management.get_resource(handle, start, end)?;
        Ok(Self {
            resource,
            _binding: retained,
        })
    }
}