use crate::device::Device;
use crate::dtype::Scalar;
use crate::error::Result;
use sysinfo::System;
/// Stateless handle for allocating host memory.
///
/// The type carries no data, so it is free to copy and every instance is
/// interchangeable with every other.
#[derive(Clone, Copy, Debug, Default)]
pub struct DefaultAllocator;
impl DefaultAllocator {
#[must_use]
pub const fn new() -> Self {
Self
}
#[must_use]
pub const fn device(&self) -> Device {
Device::Cpu
}
pub fn allocate<T: Scalar>(&self, count: usize) -> Result<*mut T> {
let mut vec = Vec::<T>::with_capacity(count);
let ptr = vec.as_mut_ptr();
core::mem::forget(vec);
Ok(ptr)
}
pub unsafe fn deallocate<T: Scalar>(&self, ptr: *mut T, count: usize) {
unsafe {
drop(Vec::from_raw_parts(ptr, 0, count));
}
}
pub unsafe fn copy<T: Scalar>(&self, dst: *mut T, src: *const T, count: usize) {
unsafe {
core::ptr::copy_nonoverlapping(src, dst, count);
}
}
pub unsafe fn zero<T: Scalar>(&self, ptr: *mut T, count: usize) {
unsafe {
core::ptr::write_bytes(ptr, 0, count);
}
}
#[must_use]
pub fn total_memory(&self) -> usize {
let sys = System::new_all();
sys.total_memory() as usize
}
#[must_use]
pub fn free_memory(&self) -> usize {
let sys = System::new_all();
sys.available_memory() as usize
}
}
/// Common query interface for memory allocators.
///
/// The trait only exposes read-only information; it does not itself define
/// how memory is obtained or released.
pub trait Allocator {
    /// Returns the [`Device`] associated with this allocator.
    fn device(&self) -> Device;

    /// Returns the total amount of memory, in units chosen by the implementor.
    fn total_memory(&self) -> usize;

    /// Returns the amount of memory currently free, in units chosen by the
    /// implementor.
    fn free_memory(&self) -> usize;
}
impl Allocator for DefaultAllocator {
fn device(&self) -> Device {
Device::Cpu
}
fn total_memory(&self) -> usize {
self.total_memory()
}
fn free_memory(&self) -> usize {
self.free_memory()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The allocator reports the CPU device and a non-zero amount of memory.
    #[test]
    fn test_default_allocator() {
        let alloc = DefaultAllocator::new();
        assert_eq!(alloc.device(), Device::Cpu);
        assert!(alloc.total_memory() > 0);
    }

    /// A buffer can be allocated, zeroed, verified to be zero, and released.
    #[test]
    fn test_allocate_deallocate() {
        const LEN: usize = 100;

        let alloc = DefaultAllocator::new();
        let ptr = alloc.allocate::<f32>(LEN).unwrap();
        assert!(!ptr.is_null());

        // SAFETY: `ptr` addresses `LEN` freshly allocated `f32` slots, all of
        // which are written by `zero` before being read, and the buffer is
        // released exactly once with the matching element count.
        unsafe {
            alloc.zero(ptr, LEN);
            for i in 0..LEN {
                // Compare raw bits so that the check is exact (and `-0.0`
                // would not be mistaken for zeroed memory).
                assert_eq!(ptr.add(i).read().to_bits(), 0);
            }
            alloc.deallocate(ptr, LEN);
        }
    }

    /// Values written to one buffer arrive unchanged in another after `copy`.
    #[test]
    fn test_copy() {
        const LEN: usize = 10;

        let alloc = DefaultAllocator::new();
        let src = alloc.allocate::<f32>(LEN).unwrap();
        let dst = alloc.allocate::<f32>(LEN).unwrap();

        // SAFETY: both buffers hold `LEN` `f32` slots and are distinct
        // allocations; every source slot is initialized before the copy, and
        // each buffer is released exactly once with the matching count.
        unsafe {
            for i in 0..LEN {
                src.add(i).write(i as f32);
            }
            alloc.copy(dst, src, LEN);
            for i in 0..LEN {
                assert_eq!(dst.add(i).read(), i as f32);
            }
            alloc.deallocate(src, LEN);
            alloc.deallocate(dst, LEN);
        }
    }
}