use crate::server::IoError;
use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use alloc::alloc::{Layout, alloc_zeroed, dealloc};
use cubecl_common::backtrace::BackTrace;
use hashbrown::HashMap;
/// In-memory storage backend that services allocations directly from the
/// host heap via the global allocator.
#[derive(Default)]
pub struct BytesStorage {
    // Live allocations, keyed by the `StorageId` embedded in the handles
    // returned from `alloc`.
    memory: HashMap<StorageId, AllocatedBytes>,
}
impl core::fmt::Debug for BytesStorage {
    /// Opaque representation: the map of raw heap pointers is not useful
    /// (or safe) to dump, so only the type name is printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "BytesStorage")
    }
}
// SAFETY: NOTE(review) — both types hold raw `*mut u8` pointers, which makes
// them `!Send` by default. Sending them across threads is presumed sound
// because each allocation is owned by exactly one `BytesStorage` and access
// is assumed to be externally synchronized by the compute runtime — TODO
// confirm no aliased cross-thread access exists at the call sites.
unsafe impl Send for BytesStorage {}
unsafe impl Send for BytesResource {}
/// A view into one region of an allocation owned by [`BytesStorage`].
///
/// Does not own the memory it points to; it is invalidated when the backing
/// allocation is deallocated.
#[derive(Debug)]
pub struct BytesResource {
    // Base pointer of the *whole* allocation (not yet offset-adjusted).
    ptr: *mut u8,
    // Offset and size of the sub-region this resource exposes.
    utilization: StorageUtilization,
}
/// An owned heap allocation: the pointer plus the layout required to free it.
struct AllocatedBytes {
    ptr: *mut u8,
    // Retained so `dealloc` can hand the exact allocation layout back to the
    // global allocator, as the allocator API requires.
    layout: Layout,
}
impl BytesResource {
    /// Returns the raw pointer to the start of this resource's region (base
    /// pointer advanced by the handle's offset) together with the region's
    /// length in bytes.
    pub fn get_write_ptr_and_length(&self) -> (*mut u8, usize) {
        (
            // SAFETY-relevant: assumes `offset` lies within the backing
            // allocation this resource was created from — TODO confirm the
            // memory manager only ever produces in-bounds offsets.
            unsafe { self.ptr.add(self.utilization.offset as usize) },
            self.utilization.size as usize,
        )
    }
    /// Mutable byte view of this resource's region.
    ///
    /// NOTE(review): the returned lifetime `'a` is unconstrained — it is tied
    /// neither to `&self` nor to the owning `BytesStorage` — so the slice can
    /// outlive the allocation. Callers must not hold it across `dealloc`.
    pub fn write<'a>(&mut self) -> &'a mut [u8] {
        let (ptr, len) = self.get_write_ptr_and_length();
        unsafe { core::slice::from_raw_parts_mut(ptr, len) }
    }
    /// Shared byte view of this resource's region; same unconstrained-lifetime
    /// caveat as [`Self::write`].
    pub fn read<'a>(&self) -> &'a [u8] {
        let (ptr, len) = self.get_write_ptr_and_length();
        unsafe { core::slice::from_raw_parts(ptr, len) }
    }
}
impl ComputeStorage for BytesStorage {
    type Resource = BytesResource;

    /// Alignment advertised to the memory manager.
    fn alignment(&self) -> usize {
        // NOTE(review): allocations below use `Layout::array::<u8>` (align 1);
        // presumably offsets produced by the memory manager respect this
        // 4-byte alignment instead — confirm against callers.
        4
    }

    /// Resolves a handle into a resource pointing at its region.
    ///
    /// Panics if the handle's allocation no longer exists: handles are only
    /// produced by `alloc`, so a missing entry is a use-after-free bug in the
    /// caller, not a recoverable condition.
    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
        let allocated_bytes = self
            .memory
            .get(&handle.id)
            .expect("storage handle must reference a live allocation");
        BytesResource {
            ptr: allocated_bytes.ptr,
            utilization: handle.utilization.clone(),
        }
    }

    /// Allocates `size` zeroed bytes and returns a handle covering all of
    /// them.
    ///
    /// Returns [`IoError::BufferTooBig`] when the request cannot be satisfied:
    /// the size does not fit the target's address space, exceeds the layout
    /// limit, or the allocator returns null.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, size))
    )]
    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
        let id = StorageId::new();
        let handle = StorageHandle {
            id,
            utilization: StorageUtilization { offset: 0, size },
        };
        if size == 0 {
            // Zero-sized allocations never touch the heap: a dangling but
            // well-aligned, non-null pointer with a zero-size layout suffices,
            // and `dealloc` skips freeing zero-size layouts.
            let memory = AllocatedBytes {
                ptr: core::ptr::NonNull::dangling().as_ptr(),
                layout: Layout::new::<()>(),
            };
            self.memory.insert(id, memory);
        } else {
            let too_big = || IoError::BufferTooBig {
                size,
                backtrace: BackTrace::capture(),
            };
            // Reject sizes that don't fit the address space instead of
            // panicking: `size as usize` would silently truncate on 32-bit
            // targets, and `Layout::array` errors past `isize::MAX`.
            let len = usize::try_from(size).map_err(|_| too_big())?;
            let layout = Layout::array::<u8>(len).map_err(|_| too_big())?;
            // SAFETY: `layout` has non-zero size (size > 0 branch), as
            // `alloc_zeroed` requires.
            let ptr = unsafe { alloc_zeroed(layout) };
            if ptr.is_null() {
                // Allocator refused the request; surface it as too big.
                return Err(too_big());
            }
            self.memory.insert(id, AllocatedBytes { ptr, layout });
        }
        Ok(handle)
    }

    /// Frees the allocation behind `id`; unknown ids are ignored.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    fn dealloc(&mut self, id: StorageId) {
        if let Some(memory) = self.memory.remove(&id)
            && memory.layout.size() > 0
        {
            // SAFETY: `ptr` was returned by `alloc_zeroed` with exactly this
            // layout; zero-sized (dangling) entries are filtered out above.
            unsafe {
                dealloc(memory.ptr, memory.layout);
            }
        }
    }

    /// Host memory needs no flushing; this is a no-op.
    fn flush(&mut self) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test_log::test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(64).unwrap();
        assert_eq!(handle.size(), 64);
        storage.dealloc(handle.id);
    }

    #[test_log::test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let full = storage.alloc(64).unwrap();
        // A sub-region [24, 32) of the same allocation.
        let slice = StorageHandle::new(
            full.id,
            StorageUtilization { offset: 24, size: 8 },
        );
        // Fill the full buffer with its own byte indices.
        for (i, byte) in storage.get(&full).write().iter_mut().enumerate() {
            *byte = i as u8;
        }
        let observed = storage.get(&slice).read().to_vec();
        storage.dealloc(full.id);
        assert_eq!(observed, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }

    #[test_log::test]
    fn test_read_after_alloc_without_write() {
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(16).unwrap();
        // Fresh allocations must come back zeroed.
        for &byte in storage.get(&handle).read() {
            assert_eq!(byte, 0);
        }
        storage.dealloc(handle.id);
    }

    #[test_log::test]
    fn test_zero_size_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let empty = storage.alloc(0).unwrap();
        assert_eq!(empty.size(), 0);
        storage.dealloc(empty.id);
    }

    #[test_log::test]
    fn test_alloc_dealloc_realloc() {
        let mut storage = BytesStorage::default();
        let first = storage.alloc(32).unwrap();
        storage.get(&first).write()[0] = 0xAA;
        storage.dealloc(first.id);
        // Allocating again after a free must still work.
        let second = storage.alloc(32).unwrap();
        storage.dealloc(second.id);
    }

    #[test_log::test]
    fn test_multiple_non_overlapping_regions() {
        let mut storage = BytesStorage::default();
        let base = storage.alloc(64).unwrap();
        // Four adjacent 16-byte windows over the same allocation.
        let mut regions = alloc::vec::Vec::new();
        for i in 0..4u64 {
            regions.push(StorageHandle::new(
                base.id,
                StorageUtilization { offset: i * 16, size: 16 },
            ));
        }
        for (value, region) in regions.iter().enumerate() {
            storage.get(region).write().fill(value as u8);
        }
        // Each window must read back exactly what was written to it.
        for (value, region) in regions.iter().enumerate() {
            for &byte in storage.get(region).read() {
                assert_eq!(byte, value as u8);
            }
        }
        storage.dealloc(base.id);
    }
}