use std::alloc::Layout;
use oxibonsai_core::tensor::BlockQ1_0G128;
pub const ALIGNMENT: usize = 64;
pub struct AlignedBuffer {
ptr: *mut f32,
len: usize,
layout: Layout,
}
unsafe impl Send for AlignedBuffer {}
unsafe impl Sync for AlignedBuffer {}
impl AlignedBuffer {
pub fn new(len: usize) -> Self {
if len == 0 {
return Self {
ptr: ALIGNMENT as *mut f32, len: 0,
layout: Layout::from_size_align(0, ALIGNMENT)
.expect("zero-size layout should always be valid"),
};
}
let byte_size = len * std::mem::size_of::<f32>();
let layout = Layout::from_size_align(byte_size, ALIGNMENT)
.expect("layout should be valid for reasonable buffer sizes");
let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
Self {
ptr: ptr.cast::<f32>(),
len,
layout,
}
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[inline]
pub fn as_ptr(&self) -> *const f32 {
self.ptr
}
#[inline]
pub fn as_mut_ptr(&mut self) -> *mut f32 {
self.ptr
}
#[inline]
pub fn as_slice(&self) -> &[f32] {
if self.len == 0 {
return &[];
}
unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
}
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [f32] {
if self.len == 0 {
return &mut [];
}
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
}
pub fn copy_from_slice(&mut self, src: &[f32]) {
assert!(
src.len() <= self.len,
"source slice length ({}) exceeds buffer length ({})",
src.len(),
self.len
);
self.as_mut_slice()[..src.len()].copy_from_slice(src);
}
}
impl Drop for AlignedBuffer {
fn drop(&mut self) {
if self.len > 0 {
unsafe {
std::alloc::dealloc(self.ptr.cast::<u8>(), self.layout);
}
}
}
}
impl std::fmt::Debug for AlignedBuffer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AlignedBuffer")
.field("len", &self.len)
.field("alignment", &ALIGNMENT)
.field("aligned", &(self.as_ptr() as usize % ALIGNMENT == 0))
.finish()
}
}
pub struct AlignedBlocks {
ptr: *mut BlockQ1_0G128,
len: usize,
layout: Layout,
}
unsafe impl Send for AlignedBlocks {}
unsafe impl Sync for AlignedBlocks {}
impl AlignedBlocks {
pub fn new(len: usize) -> Self {
if len == 0 {
return Self {
ptr: ALIGNMENT as *mut BlockQ1_0G128,
len: 0,
layout: Layout::from_size_align(0, ALIGNMENT)
.expect("zero-size layout should always be valid"),
};
}
let byte_size = len * std::mem::size_of::<BlockQ1_0G128>();
let layout = Layout::from_size_align(byte_size, ALIGNMENT)
.expect("layout should be valid for reasonable buffer sizes");
let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
Self {
ptr: ptr.cast::<BlockQ1_0G128>(),
len,
layout,
}
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[inline]
pub fn as_ptr(&self) -> *const BlockQ1_0G128 {
self.ptr
}
#[inline]
pub fn as_slice(&self) -> &[BlockQ1_0G128] {
if self.len == 0 {
return &[];
}
unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
}
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [BlockQ1_0G128] {
if self.len == 0 {
return &mut [];
}
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
}
}
impl Drop for AlignedBlocks {
fn drop(&mut self) {
if self.len > 0 {
unsafe {
std::alloc::dealloc(self.ptr.cast::<u8>(), self.layout);
}
}
}
}
impl std::fmt::Debug for AlignedBlocks {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AlignedBlocks")
.field("len", &self.len)
.field("alignment", &ALIGNMENT)
.finish()
}
}
pub fn align_to_cache_line(data: &[f32]) -> (&[f32], &[f32], &[f32]) {
if data.is_empty() {
return (&[], &[], &[]);
}
let ptr = data.as_ptr() as usize;
let f32_size = std::mem::size_of::<f32>();
let misalign_bytes = ptr % ALIGNMENT;
let prefix_len = if misalign_bytes == 0 {
0
} else {
let skip_bytes = ALIGNMENT - misalign_bytes;
skip_bytes.div_ceil(f32_size)
};
if prefix_len >= data.len() {
return (data, &[], &[]);
}
let remaining = data.len() - prefix_len;
let f32s_per_line = ALIGNMENT / f32_size; let aligned_len = (remaining / f32s_per_line) * f32s_per_line;
let prefix = &data[..prefix_len];
let aligned = &data[prefix_len..prefix_len + aligned_len];
let suffix = &data[prefix_len + aligned_len..];
(prefix, aligned, suffix)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn aligned_buffer_new_and_access() {
let buf = AlignedBuffer::new(128);
assert_eq!(buf.len(), 128);
assert!(!buf.is_empty());
for &v in buf.as_slice() {
assert!((v - 0.0).abs() < f32::EPSILON);
}
}
#[test]
fn aligned_buffer_alignment() {
let buf = AlignedBuffer::new(256);
let ptr_val = buf.as_ptr() as usize;
assert_eq!(
ptr_val % ALIGNMENT,
0,
"buffer pointer {ptr_val:#x} is not 64-byte aligned"
);
}
#[test]
fn aligned_buffer_zero_length() {
let buf = AlignedBuffer::new(0);
assert_eq!(buf.len(), 0);
assert!(buf.is_empty());
assert_eq!(buf.as_slice().len(), 0);
}
#[test]
fn aligned_buffer_large() {
let buf = AlignedBuffer::new(10_000);
assert_eq!(buf.len(), 10_000);
assert_eq!(buf.as_ptr() as usize % ALIGNMENT, 0);
}
#[test]
fn aligned_buffer_copy_from_slice() {
let mut buf = AlignedBuffer::new(8);
let src = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
buf.copy_from_slice(&src);
assert_eq!(buf.as_slice(), &src);
}
#[test]
fn aligned_buffer_mut_slice() {
let mut buf = AlignedBuffer::new(4);
{
let s = buf.as_mut_slice();
s[0] = 42.0;
s[3] = -1.0;
}
assert!((buf.as_slice()[0] - 42.0).abs() < f32::EPSILON);
assert!((buf.as_slice()[3] - (-1.0)).abs() < f32::EPSILON);
}
#[test]
fn aligned_blocks_new_and_access() {
let blocks = AlignedBlocks::new(16);
assert_eq!(blocks.len(), 16);
assert!(!blocks.is_empty());
assert_eq!(blocks.as_ptr() as usize % ALIGNMENT, 0);
}
#[test]
fn aligned_blocks_zero_length() {
let blocks = AlignedBlocks::new(0);
assert_eq!(blocks.len(), 0);
assert!(blocks.is_empty());
assert_eq!(blocks.as_slice().len(), 0);
}
#[test]
fn align_to_cache_line_empty() {
let data: &[f32] = &[];
let (prefix, aligned, suffix) = align_to_cache_line(data);
assert!(prefix.is_empty());
assert!(aligned.is_empty());
assert!(suffix.is_empty());
}
#[test]
fn align_to_cache_line_already_aligned() {
let buf = AlignedBuffer::new(64);
let data = buf.as_slice();
let (prefix, aligned, suffix) = align_to_cache_line(data);
assert!(
prefix.is_empty(),
"prefix should be empty for aligned buffer"
);
assert_eq!(aligned.len() + suffix.len(), data.len());
}
#[test]
fn align_to_cache_line_preserves_data() {
let buf = AlignedBuffer::new(128);
let data = buf.as_slice();
let (prefix, aligned, suffix) = align_to_cache_line(data);
assert_eq!(
prefix.len() + aligned.len() + suffix.len(),
data.len(),
"split must preserve total length"
);
}
#[test]
fn aligned_buffer_debug() {
let buf = AlignedBuffer::new(32);
let dbg = format!("{buf:?}");
assert!(dbg.contains("AlignedBuffer"));
assert!(dbg.contains("32"));
}
#[test]
fn aligned_blocks_debug() {
let blocks = AlignedBlocks::new(8);
let dbg = format!("{blocks:?}");
assert!(dbg.contains("AlignedBlocks"));
}
}