Skip to main content

oxibonsai_kernels/
aligned.rs

1//! Cache-line aligned memory allocations for SIMD kernel operations.
2//!
//! The standard `Vec<f32>` alignment (4 bytes) is insufficient for aligned SIMD
//! loads and stores, which require 32-byte (AVX) or 64-byte (AVX-512) alignment;
//! 64 bytes is also a common cache-line size. This module provides aligned buffer
//! types that guarantee 64-byte alignment for all allocations.
7
8use std::alloc::Layout;
9
10use oxibonsai_core::tensor::BlockQ1_0G128;
11
/// Byte alignment applied to every buffer allocated by this module.
///
/// 64 bytes matches a common cache-line size and the width of one AVX-512 register.
pub const ALIGNMENT: usize = 64;
14
15/// A cache-line aligned buffer of `f32` values.
16///
17/// Guarantees that the backing memory starts at a 64-byte boundary,
18/// which is optimal for AVX-512 aligned loads and cache line prefetch.
19///
20/// # Example
21///
22/// ```
23/// use oxibonsai_kernels::aligned::AlignedBuffer;
24///
25/// let buf = AlignedBuffer::new(256);
26/// assert_eq!(buf.len(), 256);
27/// assert_eq!(buf.as_ptr() as usize % 64, 0);
28/// ```
29pub struct AlignedBuffer {
30    /// Raw pointer to the aligned allocation.
31    ptr: *mut f32,
32    /// Number of f32 elements.
33    len: usize,
34    /// Layout used for deallocation.
35    layout: Layout,
36}
37
38// SAFETY: The buffer owns its allocation exclusively.
39unsafe impl Send for AlignedBuffer {}
40// SAFETY: Shared references to the buffer are read-only.
41unsafe impl Sync for AlignedBuffer {}
42
43impl AlignedBuffer {
44    /// Allocate a new zero-initialized aligned buffer of `len` f32 elements.
45    ///
46    /// The returned buffer is guaranteed to have 64-byte alignment.
47    /// A zero-length buffer produces a valid (dangling but aligned) pointer.
48    pub fn new(len: usize) -> Self {
49        if len == 0 {
50            return Self {
51                ptr: ALIGNMENT as *mut f32, // aligned dangling pointer
52                len: 0,
53                layout: Layout::from_size_align(0, ALIGNMENT)
54                    .expect("zero-size layout should always be valid"),
55            };
56        }
57
58        let byte_size = len * std::mem::size_of::<f32>();
59        let layout = Layout::from_size_align(byte_size, ALIGNMENT)
60            .expect("layout should be valid for reasonable buffer sizes");
61
62        // SAFETY: layout has non-zero size and valid alignment.
63        let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
64        if ptr.is_null() {
65            std::alloc::handle_alloc_error(layout);
66        }
67
68        Self {
69            ptr: ptr.cast::<f32>(),
70            len,
71            layout,
72        }
73    }
74
75    /// Returns the number of f32 elements in this buffer.
76    #[inline]
77    pub fn len(&self) -> usize {
78        self.len
79    }
80
81    /// Returns `true` if the buffer contains no elements.
82    #[inline]
83    pub fn is_empty(&self) -> bool {
84        self.len == 0
85    }
86
87    /// Returns the raw pointer to the start of the buffer.
88    #[inline]
89    pub fn as_ptr(&self) -> *const f32 {
90        self.ptr
91    }
92
93    /// Returns a mutable raw pointer to the start of the buffer.
94    #[inline]
95    pub fn as_mut_ptr(&mut self) -> *mut f32 {
96        self.ptr
97    }
98
99    /// View the buffer as an immutable slice.
100    #[inline]
101    pub fn as_slice(&self) -> &[f32] {
102        if self.len == 0 {
103            return &[];
104        }
105        // SAFETY: ptr is valid for len elements, properly aligned, and initialized to zero.
106        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
107    }
108
109    /// View the buffer as a mutable slice.
110    #[inline]
111    pub fn as_mut_slice(&mut self) -> &mut [f32] {
112        if self.len == 0 {
113            return &mut [];
114        }
115        // SAFETY: ptr is valid for len elements, properly aligned, and we have exclusive access.
116        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
117    }
118
119    /// Copy data from a slice into this buffer.
120    ///
121    /// Panics if `src.len() > self.len()`.
122    pub fn copy_from_slice(&mut self, src: &[f32]) {
123        assert!(
124            src.len() <= self.len,
125            "source slice length ({}) exceeds buffer length ({})",
126            src.len(),
127            self.len
128        );
129        self.as_mut_slice()[..src.len()].copy_from_slice(src);
130    }
131}
132
133impl Drop for AlignedBuffer {
134    fn drop(&mut self) {
135        if self.len > 0 {
136            // SAFETY: ptr was allocated with this layout in `new`.
137            unsafe {
138                std::alloc::dealloc(self.ptr.cast::<u8>(), self.layout);
139            }
140        }
141    }
142}
143
144impl std::fmt::Debug for AlignedBuffer {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        f.debug_struct("AlignedBuffer")
147            .field("len", &self.len)
148            .field("alignment", &ALIGNMENT)
149            .field("aligned", &(self.as_ptr() as usize % ALIGNMENT == 0))
150            .finish()
151    }
152}
153
154/// A cache-line aligned buffer of `BlockQ1_0G128` values.
155///
156/// Same alignment guarantees as [`AlignedBuffer`] but for quantized weight blocks.
157pub struct AlignedBlocks {
158    /// Raw pointer to the aligned allocation.
159    ptr: *mut BlockQ1_0G128,
160    /// Number of block elements.
161    len: usize,
162    /// Layout used for deallocation.
163    layout: Layout,
164}
165
166// SAFETY: The buffer owns its allocation exclusively.
167unsafe impl Send for AlignedBlocks {}
168// SAFETY: Shared references to the buffer are read-only.
169unsafe impl Sync for AlignedBlocks {}
170
171impl AlignedBlocks {
172    /// Allocate a new zero-initialized aligned buffer of `len` blocks.
173    pub fn new(len: usize) -> Self {
174        if len == 0 {
175            return Self {
176                ptr: ALIGNMENT as *mut BlockQ1_0G128,
177                len: 0,
178                layout: Layout::from_size_align(0, ALIGNMENT)
179                    .expect("zero-size layout should always be valid"),
180            };
181        }
182
183        let byte_size = len * std::mem::size_of::<BlockQ1_0G128>();
184        let layout = Layout::from_size_align(byte_size, ALIGNMENT)
185            .expect("layout should be valid for reasonable buffer sizes");
186
187        // SAFETY: layout has non-zero size and valid alignment.
188        let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
189        if ptr.is_null() {
190            std::alloc::handle_alloc_error(layout);
191        }
192
193        Self {
194            ptr: ptr.cast::<BlockQ1_0G128>(),
195            len,
196            layout,
197        }
198    }
199
200    /// Returns the number of blocks.
201    #[inline]
202    pub fn len(&self) -> usize {
203        self.len
204    }
205
206    /// Returns `true` if the buffer contains no blocks.
207    #[inline]
208    pub fn is_empty(&self) -> bool {
209        self.len == 0
210    }
211
212    /// Returns the raw pointer to the start of the buffer.
213    #[inline]
214    pub fn as_ptr(&self) -> *const BlockQ1_0G128 {
215        self.ptr
216    }
217
218    /// View the buffer as an immutable slice.
219    #[inline]
220    pub fn as_slice(&self) -> &[BlockQ1_0G128] {
221        if self.len == 0 {
222            return &[];
223        }
224        // SAFETY: ptr is valid for len elements, properly aligned, and zero-initialized.
225        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
226    }
227
228    /// View the buffer as a mutable slice.
229    #[inline]
230    pub fn as_mut_slice(&mut self) -> &mut [BlockQ1_0G128] {
231        if self.len == 0 {
232            return &mut [];
233        }
234        // SAFETY: ptr is valid for len elements, properly aligned, and we have exclusive access.
235        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
236    }
237}
238
239impl Drop for AlignedBlocks {
240    fn drop(&mut self) {
241        if self.len > 0 {
242            // SAFETY: ptr was allocated with this layout in `new`.
243            unsafe {
244                std::alloc::dealloc(self.ptr.cast::<u8>(), self.layout);
245            }
246        }
247    }
248}
249
250impl std::fmt::Debug for AlignedBlocks {
251    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252        f.debug_struct("AlignedBlocks")
253            .field("len", &self.len)
254            .field("alignment", &ALIGNMENT)
255            .finish()
256    }
257}
258
259/// Split a slice at cache-line boundaries.
260///
261/// Returns `(prefix, aligned_middle, suffix)` where `aligned_middle` starts
262/// at a 64-byte aligned address and has a 64-byte aligned length (in bytes).
263///
264/// If the input is already aligned, `prefix` will be empty.
265/// If the input is too short to contain any aligned portion, the entire
266/// slice is returned as `prefix` with empty `aligned_middle` and `suffix`.
267pub fn align_to_cache_line(data: &[f32]) -> (&[f32], &[f32], &[f32]) {
268    if data.is_empty() {
269        return (&[], &[], &[]);
270    }
271
272    let ptr = data.as_ptr() as usize;
273    let f32_size = std::mem::size_of::<f32>();
274
275    // How many bytes past the last alignment boundary?
276    let misalign_bytes = ptr % ALIGNMENT;
277
278    // Number of f32s to skip to reach alignment
279    let prefix_len = if misalign_bytes == 0 {
280        0
281    } else {
282        let skip_bytes = ALIGNMENT - misalign_bytes;
283        // Round up to whole f32s
284        skip_bytes.div_ceil(f32_size)
285    };
286
287    if prefix_len >= data.len() {
288        // Entire slice is in the prefix — no aligned middle
289        return (data, &[], &[]);
290    }
291
292    let remaining = data.len() - prefix_len;
293
294    // How many f32s fit in a cache-line-aligned chunk?
295    let f32s_per_line = ALIGNMENT / f32_size; // 16
296    let aligned_len = (remaining / f32s_per_line) * f32s_per_line;
297
298    let prefix = &data[..prefix_len];
299    let aligned = &data[prefix_len..prefix_len + aligned_len];
300    let suffix = &data[prefix_len + aligned_len..];
301
302    (prefix, aligned, suffix)
303}
304
305#[cfg(test)]
306mod tests {
307    use super::*;
308
309    #[test]
310    fn aligned_buffer_new_and_access() {
311        let buf = AlignedBuffer::new(128);
312        assert_eq!(buf.len(), 128);
313        assert!(!buf.is_empty());
314        // All zeros
315        for &v in buf.as_slice() {
316            assert!((v - 0.0).abs() < f32::EPSILON);
317        }
318    }
319
320    #[test]
321    fn aligned_buffer_alignment() {
322        let buf = AlignedBuffer::new(256);
323        let ptr_val = buf.as_ptr() as usize;
324        assert_eq!(
325            ptr_val % ALIGNMENT,
326            0,
327            "buffer pointer {ptr_val:#x} is not 64-byte aligned"
328        );
329    }
330
331    #[test]
332    fn aligned_buffer_zero_length() {
333        let buf = AlignedBuffer::new(0);
334        assert_eq!(buf.len(), 0);
335        assert!(buf.is_empty());
336        assert_eq!(buf.as_slice().len(), 0);
337    }
338
339    #[test]
340    fn aligned_buffer_large() {
341        let buf = AlignedBuffer::new(10_000);
342        assert_eq!(buf.len(), 10_000);
343        assert_eq!(buf.as_ptr() as usize % ALIGNMENT, 0);
344    }
345
346    #[test]
347    fn aligned_buffer_copy_from_slice() {
348        let mut buf = AlignedBuffer::new(8);
349        let src = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
350        buf.copy_from_slice(&src);
351        assert_eq!(buf.as_slice(), &src);
352    }
353
354    #[test]
355    fn aligned_buffer_mut_slice() {
356        let mut buf = AlignedBuffer::new(4);
357        {
358            let s = buf.as_mut_slice();
359            s[0] = 42.0;
360            s[3] = -1.0;
361        }
362        assert!((buf.as_slice()[0] - 42.0).abs() < f32::EPSILON);
363        assert!((buf.as_slice()[3] - (-1.0)).abs() < f32::EPSILON);
364    }
365
366    #[test]
367    fn aligned_blocks_new_and_access() {
368        let blocks = AlignedBlocks::new(16);
369        assert_eq!(blocks.len(), 16);
370        assert!(!blocks.is_empty());
371        assert_eq!(blocks.as_ptr() as usize % ALIGNMENT, 0);
372    }
373
374    #[test]
375    fn aligned_blocks_zero_length() {
376        let blocks = AlignedBlocks::new(0);
377        assert_eq!(blocks.len(), 0);
378        assert!(blocks.is_empty());
379        assert_eq!(blocks.as_slice().len(), 0);
380    }
381
382    #[test]
383    fn align_to_cache_line_empty() {
384        let data: &[f32] = &[];
385        let (prefix, aligned, suffix) = align_to_cache_line(data);
386        assert!(prefix.is_empty());
387        assert!(aligned.is_empty());
388        assert!(suffix.is_empty());
389    }
390
391    #[test]
392    fn align_to_cache_line_already_aligned() {
393        let buf = AlignedBuffer::new(64);
394        let data = buf.as_slice();
395        let (prefix, aligned, suffix) = align_to_cache_line(data);
396        // Already aligned, so prefix should be empty
397        assert!(
398            prefix.is_empty(),
399            "prefix should be empty for aligned buffer"
400        );
401        assert_eq!(aligned.len() + suffix.len(), data.len());
402    }
403
404    #[test]
405    fn align_to_cache_line_preserves_data() {
406        let buf = AlignedBuffer::new(128);
407        let data = buf.as_slice();
408        let (prefix, aligned, suffix) = align_to_cache_line(data);
409        // Total length preserved
410        assert_eq!(
411            prefix.len() + aligned.len() + suffix.len(),
412            data.len(),
413            "split must preserve total length"
414        );
415    }
416
417    #[test]
418    fn aligned_buffer_debug() {
419        let buf = AlignedBuffer::new(32);
420        let dbg = format!("{buf:?}");
421        assert!(dbg.contains("AlignedBuffer"));
422        assert!(dbg.contains("32"));
423    }
424
425    #[test]
426    fn aligned_blocks_debug() {
427        let blocks = AlignedBlocks::new(8);
428        let dbg = format!("{blocks:?}");
429        assert!(dbg.contains("AlignedBlocks"));
430    }
431}