use std::alloc::{Layout, alloc, dealloc};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use anyhow::{Result, anyhow};
use ronn_core::{DataType, MemoryInfo, MemoryType, TensorAllocator, TensorBuffer};
/// Tensor allocator that sizes buffers for BitNet-style packed tensors.
///
/// `Bool` tensors are packed eight elements per byte and `U8` tensors four elements per
/// byte; `F32`/`F16`/`BF16` use their natural element sizes. All buffers are reported as
/// `MemoryType::SystemRAM`.
#[derive(Debug)]
pub struct BitNetMemoryAllocator {
    /// Bytes currently handed out by `allocate` and not yet returned via `deallocate`.
    allocated_bytes: Arc<AtomicUsize>,
    /// High-water mark of `allocated_bytes` since construction.
    peak_bytes: Arc<AtomicUsize>,
    /// Nominal capacity reported through `MemoryInfo`; `allocate` does not enforce it.
    total_bytes: usize,
}
impl BitNetMemoryAllocator {
    /// Nominal capacity reported through `MemoryInfo::total_bytes` (8 GiB).
    ///
    /// Stored as `u64` so the constant expression itself cannot overflow on targets where
    /// `usize` is 32 bits; `new` clamps it into `usize`.
    const DEFAULT_TOTAL_BYTES: u64 = 8 * 1024 * 1024 * 1024;

    /// Creates an allocator with zeroed usage counters.
    ///
    /// `total_bytes` is 8 GiB, clamped to `usize::MAX` on targets with a smaller address
    /// space. The value is only reported by `get_memory_info`; `allocate` does not enforce it.
    pub fn new() -> Self {
        let total_bytes = usize::try_from(Self::DEFAULT_TOTAL_BYTES).unwrap_or(usize::MAX);
        Self {
            allocated_bytes: Arc::new(AtomicUsize::new(0)),
            peak_bytes: Arc::new(AtomicUsize::new(0)),
            total_bytes,
        }
    }

    /// Returns the number of bytes needed to store a tensor of `shape` and `dtype`.
    ///
    /// - `Bool`: packed eight elements per byte (one bit each).
    /// - `U8`: packed four elements per byte (two bits each). Presumably ternary weights,
    ///   as the tests' naming suggests.
    /// - `F32`: 4 bytes per element; `F16` / `BF16`: 2 bytes per element.
    ///
    /// An empty `shape` is treated as a scalar (one element). A shape containing a zero
    /// dimension yields a size of 0.
    ///
    /// # Errors
    ///
    /// Returns an error for any other `DataType`, or when the element count or the byte
    /// size does not fit in `usize`.
    fn calculate_packed_size(&self, shape: &[usize], dtype: DataType) -> Result<usize> {
        // Checked product: an unchecked `product()` panics in debug and wraps in release.
        let element_count = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| anyhow!("Tensor shape {:?} has too many elements to address", shape))?;

        // Ceiling division that cannot overflow (unlike `(n + d - 1) / d`).
        let ceil_div = |n: usize, d: usize| n / d + usize::from(n % d != 0);
        let bytes_for = |element_size: usize| {
            element_count.checked_mul(element_size).ok_or_else(|| {
                anyhow!("Byte size of tensor with shape {:?} overflows usize", shape)
            })
        };

        match dtype {
            DataType::Bool => Ok(ceil_div(element_count, 8)),
            DataType::U8 => Ok(ceil_div(element_count, 4)),
            DataType::F32 => bytes_for(std::mem::size_of::<f32>()),
            DataType::F16 | DataType::BF16 => bytes_for(std::mem::size_of::<u16>()),
            _ => Err(anyhow!(
                "Unsupported data type for BitNet allocator: {:?}",
                dtype
            )),
        }
    }

    /// Alignment (in bytes) used for buffers of `dtype`.
    ///
    /// 32 for `Bool`, `U8` and `F32`; 16 for `F16` / `BF16`; 8 for anything else.
    fn get_alignment(&self, dtype: DataType) -> usize {
        match dtype {
            DataType::Bool | DataType::U8 | DataType::F32 => 32,
            DataType::F16 | DataType::BF16 => 16,
            _ => 8,
        }
    }

    /// Records an allocation (`allocating == true`) or a deallocation of `size` bytes.
    ///
    /// The peak is raised with `fetch_max`, so two concurrent allocations cannot overwrite
    /// a higher peak with a lower one (a separate load followed by a store could).
    /// Deallocation saturates at zero instead of wrapping if more bytes are released than
    /// were recorded.
    fn update_stats(&self, size: usize, allocating: bool) {
        if allocating {
            // `fetch_add` returns the previous value; wrapping_add mirrors the atomic's
            // own wrapping arithmetic so `new_allocated` equals the stored value.
            let new_allocated = self
                .allocated_bytes
                .fetch_add(size, Ordering::Relaxed)
                .wrapping_add(size);
            self.peak_bytes.fetch_max(new_allocated, Ordering::Relaxed);
        } else {
            // The closure always returns `Some`, so this update cannot fail.
            let _ = self
                .allocated_bytes
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                    Some(current.saturating_sub(size))
                });
        }
    }
}
impl Default for BitNetMemoryAllocator {
fn default() -> Self {
Self::new()
}
}
impl TensorAllocator for BitNetMemoryAllocator {
    /// Allocates an uninitialized buffer sized by `calculate_packed_size` and aligned per
    /// `get_alignment`, and records it in the usage statistics.
    ///
    /// A zero-byte request (e.g. a shape with a 0 dimension) does not touch the global
    /// allocator. Calling `alloc` with a zero-sized layout is undefined behaviour, so a
    /// non-null, suitably aligned dangling pointer is returned instead. It is valid for
    /// zero-length access only.
    ///
    /// # Errors
    ///
    /// Fails for unsupported dtypes or overflowing shapes, for an invalid layout, or when
    /// the global allocator returns null.
    fn allocate(&self, shape: &[usize], dtype: DataType) -> Result<TensorBuffer> {
        let size = self.calculate_packed_size(shape, dtype)?;
        let alignment = self.get_alignment(dtype);
        let layout = Layout::from_size_align(size, alignment)
            .map_err(|e| anyhow!("Invalid memory layout: {}", e))?;

        let ptr = if layout.size() == 0 {
            // Same strategy `Vec` uses for empty storage: an address equal to the
            // alignment is non-null and aligned, and must never be passed to `dealloc`.
            std::ptr::null_mut::<u8>().wrapping_add(layout.align())
        } else {
            // SAFETY: `layout` has a non-zero size, as `GlobalAlloc::alloc` requires.
            let raw_ptr = unsafe { alloc(layout) };
            if raw_ptr.is_null() {
                return Err(anyhow!("Memory allocation failed for {} bytes", size));
            }
            raw_ptr
        };

        self.update_stats(size, true);
        Ok(TensorBuffer {
            ptr,
            size,
            alignment,
            memory_type: MemoryType::SystemRAM,
        })
    }

    /// Releases a buffer previously returned by `allocate` and updates the statistics.
    ///
    /// Zero-sized buffers carry a dangling pointer and are not passed to `dealloc`.
    ///
    /// # Errors
    ///
    /// Fails if `buffer.size` / `buffer.alignment` do not form a valid `Layout`.
    fn deallocate(&self, buffer: TensorBuffer) -> Result<()> {
        let layout = Layout::from_size_align(buffer.size, buffer.alignment)
            .map_err(|e| anyhow!("Invalid memory layout for deallocation: {}", e))?;
        if layout.size() != 0 {
            // SAFETY: assumes `buffer` was produced by `allocate` on this allocator, so
            // `buffer.ptr` came from `alloc` with this exact size and alignment.
            // NOTE(review): the trait signature cannot enforce this; callers must uphold it.
            unsafe { dealloc(buffer.ptr, layout) };
        }
        self.update_stats(buffer.size, false);
        Ok(())
    }

    /// Snapshot of the nominal capacity, current usage and peak usage, in bytes.
    fn get_memory_info(&self) -> MemoryInfo {
        MemoryInfo {
            total_bytes: self.total_bytes,
            allocated_bytes: self.allocated_bytes.load(Ordering::Relaxed),
            peak_bytes: self.peak_bytes.load(Ordering::Relaxed),
        }
    }
}
/// Builds a fresh [`BitNetMemoryAllocator`] behind a shared, type-erased
/// `TensorAllocator` handle.
pub fn create_bitnet_allocator() -> Arc<dyn TensorAllocator> {
    let bitnet = BitNetMemoryAllocator::default();
    Arc::new(bitnet)
}
/// Running summary of packed tensors recorded via `add_binary_tensor` and
/// `add_ternary_tensor`.
#[derive(Debug, Clone)]
pub struct BitNetMemoryStats {
    /// Number of tensors recorded with `add_binary_tensor`.
    pub binary_tensor_count: usize,
    /// Number of tensors recorded with `add_ternary_tensor`.
    pub ternary_tensor_count: usize,
    /// Sum over recorded tensors of `original_size - compressed_size`, each term
    /// clamped at zero.
    pub memory_saved_bytes: usize,
    /// Mean of fixed nominal ratios (32.0 per binary tensor, 16.0 per ternary tensor);
    /// 1.0 until a tensor is recorded. The recorded sizes do not affect it.
    pub average_compression_ratio: f32,
}
impl BitNetMemoryStats {
pub fn new() -> Self {
Self {
binary_tensor_count: 0,
ternary_tensor_count: 0,
memory_saved_bytes: 0,
average_compression_ratio: 1.0,
}
}
pub fn add_binary_tensor(&mut self, original_size: usize, compressed_size: usize) {
self.binary_tensor_count += 1;
self.memory_saved_bytes += original_size.saturating_sub(compressed_size);
self.recalculate_compression_ratio();
}
pub fn add_ternary_tensor(&mut self, original_size: usize, compressed_size: usize) {
self.ternary_tensor_count += 1;
self.memory_saved_bytes += original_size.saturating_sub(compressed_size);
self.recalculate_compression_ratio();
}
fn recalculate_compression_ratio(&mut self) {
let total_tensors = self.binary_tensor_count + self.ternary_tensor_count;
if total_tensors > 0 {
let binary_contribution = self.binary_tensor_count as f32 * 32.0; let ternary_contribution = self.ternary_tensor_count as f32 * 16.0;
self.average_compression_ratio =
(binary_contribution + ternary_contribution) / total_tensors as f32;
}
}
}
impl Default for BitNetMemoryStats {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bitnet_allocator_creation() {
        let info = BitNetMemoryAllocator::new().get_memory_info();
        assert_eq!(info.allocated_bytes, 0);
        assert_eq!(info.peak_bytes, 0);
        assert!(info.total_bytes > 0);
    }

    #[test]
    fn test_binary_tensor_allocation() -> Result<()> {
        let bitnet = BitNetMemoryAllocator::new();

        // 1000 one-bit elements pack into ceil(1000 / 8) = 125 bytes.
        let packed = bitnet.allocate(&[1000], DataType::Bool)?;
        assert_eq!(packed.size, 125);
        assert_eq!(packed.alignment, 32);
        assert_eq!(packed.memory_type, MemoryType::SystemRAM);

        let during = bitnet.get_memory_info();
        assert_eq!((during.allocated_bytes, during.peak_bytes), (125, 125));

        bitnet.deallocate(packed)?;
        let after = bitnet.get_memory_info();
        assert_eq!(after.allocated_bytes, 0);
        // The peak is a high-water mark and survives deallocation.
        assert_eq!(after.peak_bytes, 125);
        Ok(())
    }

    #[test]
    fn test_ternary_tensor_allocation() -> Result<()> {
        let bitnet = BitNetMemoryAllocator::new();
        // 100 two-bit elements pack into ceil(100 / 4) = 25 bytes.
        let packed = bitnet.allocate(&[100], DataType::U8)?;
        assert_eq!((packed.size, packed.alignment), (25, 32));
        bitnet.deallocate(packed)?;
        Ok(())
    }

    #[test]
    fn test_memory_stats() {
        let mut summary = BitNetMemoryStats::new();
        summary.add_binary_tensor(4000, 125);
        summary.add_ternary_tensor(1600, 100);

        assert_eq!(summary.binary_tensor_count, 1);
        assert_eq!(summary.ternary_tensor_count, 1);
        assert_eq!(summary.memory_saved_bytes, (4000 - 125) + (1600 - 100));
        assert!(summary.average_compression_ratio > 20.0);
    }

    #[test]
    fn test_multi_dimensional_allocation() -> Result<()> {
        let bitnet = BitNetMemoryAllocator::new();
        // 32 * 32 = 1024 one-bit elements -> 128 bytes.
        let packed = bitnet.allocate(&[32, 32], DataType::Bool)?;
        assert_eq!(packed.size, 128);
        bitnet.deallocate(packed)?;
        Ok(())
    }
}