use std::alloc::Layout;

use oxibonsai_core::tensor::BlockQ1_0G128;
/// Byte alignment applied to every buffer allocated in this module.
///
/// 64 bytes; [`align_to_cache_line`] treats this value as the cache-line size.
pub const ALIGNMENT: usize = 64;

// `Layout::from_size_align` rejects alignments that are not powers of two, so
// a bad value here would make every constructor panic at runtime. Fail the
// build instead.
const _: () = assert!(ALIGNMENT.is_power_of_two());
15pub struct AlignedBuffer {
30 ptr: *mut f32,
32 len: usize,
34 layout: Layout,
36}
37
38unsafe impl Send for AlignedBuffer {}
40unsafe impl Sync for AlignedBuffer {}
42
43impl AlignedBuffer {
44 pub fn new(len: usize) -> Self {
49 if len == 0 {
50 return Self {
51 ptr: ALIGNMENT as *mut f32, len: 0,
53 layout: Layout::from_size_align(0, ALIGNMENT)
54 .expect("zero-size layout should always be valid"),
55 };
56 }
57
58 let byte_size = len * std::mem::size_of::<f32>();
59 let layout = Layout::from_size_align(byte_size, ALIGNMENT)
60 .expect("layout should be valid for reasonable buffer sizes");
61
62 let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
64 if ptr.is_null() {
65 std::alloc::handle_alloc_error(layout);
66 }
67
68 Self {
69 ptr: ptr.cast::<f32>(),
70 len,
71 layout,
72 }
73 }
74
75 #[inline]
77 pub fn len(&self) -> usize {
78 self.len
79 }
80
81 #[inline]
83 pub fn is_empty(&self) -> bool {
84 self.len == 0
85 }
86
87 #[inline]
89 pub fn as_ptr(&self) -> *const f32 {
90 self.ptr
91 }
92
93 #[inline]
95 pub fn as_mut_ptr(&mut self) -> *mut f32 {
96 self.ptr
97 }
98
99 #[inline]
101 pub fn as_slice(&self) -> &[f32] {
102 if self.len == 0 {
103 return &[];
104 }
105 unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
107 }
108
109 #[inline]
111 pub fn as_mut_slice(&mut self) -> &mut [f32] {
112 if self.len == 0 {
113 return &mut [];
114 }
115 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
117 }
118
119 pub fn copy_from_slice(&mut self, src: &[f32]) {
123 assert!(
124 src.len() <= self.len,
125 "source slice length ({}) exceeds buffer length ({})",
126 src.len(),
127 self.len
128 );
129 self.as_mut_slice()[..src.len()].copy_from_slice(src);
130 }
131}
132
133impl Drop for AlignedBuffer {
134 fn drop(&mut self) {
135 if self.len > 0 {
136 unsafe {
138 std::alloc::dealloc(self.ptr.cast::<u8>(), self.layout);
139 }
140 }
141 }
142}
143
144impl std::fmt::Debug for AlignedBuffer {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 f.debug_struct("AlignedBuffer")
147 .field("len", &self.len)
148 .field("alignment", &ALIGNMENT)
149 .field("aligned", &(self.as_ptr() as usize % ALIGNMENT == 0))
150 .finish()
151 }
152}
153
154pub struct AlignedBlocks {
158 ptr: *mut BlockQ1_0G128,
160 len: usize,
162 layout: Layout,
164}
165
166unsafe impl Send for AlignedBlocks {}
168unsafe impl Sync for AlignedBlocks {}
170
171impl AlignedBlocks {
172 pub fn new(len: usize) -> Self {
174 if len == 0 {
175 return Self {
176 ptr: ALIGNMENT as *mut BlockQ1_0G128,
177 len: 0,
178 layout: Layout::from_size_align(0, ALIGNMENT)
179 .expect("zero-size layout should always be valid"),
180 };
181 }
182
183 let byte_size = len * std::mem::size_of::<BlockQ1_0G128>();
184 let layout = Layout::from_size_align(byte_size, ALIGNMENT)
185 .expect("layout should be valid for reasonable buffer sizes");
186
187 let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
189 if ptr.is_null() {
190 std::alloc::handle_alloc_error(layout);
191 }
192
193 Self {
194 ptr: ptr.cast::<BlockQ1_0G128>(),
195 len,
196 layout,
197 }
198 }
199
200 #[inline]
202 pub fn len(&self) -> usize {
203 self.len
204 }
205
206 #[inline]
208 pub fn is_empty(&self) -> bool {
209 self.len == 0
210 }
211
212 #[inline]
214 pub fn as_ptr(&self) -> *const BlockQ1_0G128 {
215 self.ptr
216 }
217
218 #[inline]
220 pub fn as_slice(&self) -> &[BlockQ1_0G128] {
221 if self.len == 0 {
222 return &[];
223 }
224 unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
226 }
227
228 #[inline]
230 pub fn as_mut_slice(&mut self) -> &mut [BlockQ1_0G128] {
231 if self.len == 0 {
232 return &mut [];
233 }
234 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
236 }
237}
238
239impl Drop for AlignedBlocks {
240 fn drop(&mut self) {
241 if self.len > 0 {
242 unsafe {
244 std::alloc::dealloc(self.ptr.cast::<u8>(), self.layout);
245 }
246 }
247 }
248}
249
250impl std::fmt::Debug for AlignedBlocks {
251 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252 f.debug_struct("AlignedBlocks")
253 .field("len", &self.len)
254 .field("alignment", &ALIGNMENT)
255 .finish()
256 }
257}
258
259pub fn align_to_cache_line(data: &[f32]) -> (&[f32], &[f32], &[f32]) {
268 if data.is_empty() {
269 return (&[], &[], &[]);
270 }
271
272 let ptr = data.as_ptr() as usize;
273 let f32_size = std::mem::size_of::<f32>();
274
275 let misalign_bytes = ptr % ALIGNMENT;
277
278 let prefix_len = if misalign_bytes == 0 {
280 0
281 } else {
282 let skip_bytes = ALIGNMENT - misalign_bytes;
283 skip_bytes.div_ceil(f32_size)
285 };
286
287 if prefix_len >= data.len() {
288 return (data, &[], &[]);
290 }
291
292 let remaining = data.len() - prefix_len;
293
294 let f32s_per_line = ALIGNMENT / f32_size; let aligned_len = (remaining / f32s_per_line) * f32s_per_line;
297
298 let prefix = &data[..prefix_len];
299 let aligned = &data[prefix_len..prefix_len + aligned_len];
300 let suffix = &data[prefix_len + aligned_len..];
301
302 (prefix, aligned, suffix)
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn aligned_buffer_new_and_access() {
        let buf = AlignedBuffer::new(128);
        assert_eq!(buf.len(), 128);
        assert!(!buf.is_empty());
        // Freshly allocated storage is zero-filled.
        for &v in buf.as_slice() {
            assert!((v - 0.0).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn aligned_buffer_alignment() {
        let buf = AlignedBuffer::new(256);
        let ptr_val = buf.as_ptr() as usize;
        assert_eq!(
            ptr_val % ALIGNMENT,
            0,
            "buffer pointer {ptr_val:#x} is not 64-byte aligned"
        );
    }

    #[test]
    fn aligned_buffer_zero_length() {
        let buf = AlignedBuffer::new(0);
        assert_eq!(buf.len(), 0);
        assert!(buf.is_empty());
        assert_eq!(buf.as_slice().len(), 0);
    }

    #[test]
    fn aligned_buffer_large() {
        let buf = AlignedBuffer::new(10_000);
        assert_eq!(buf.len(), 10_000);
        assert_eq!(buf.as_ptr() as usize % ALIGNMENT, 0);
    }

    #[test]
    fn aligned_buffer_copy_from_slice() {
        let mut buf = AlignedBuffer::new(8);
        let src = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        buf.copy_from_slice(&src);
        assert_eq!(buf.as_slice(), &src);
    }

    #[test]
    #[should_panic(expected = "exceeds buffer length")]
    fn aligned_buffer_copy_from_slice_rejects_longer_source() {
        let mut buf = AlignedBuffer::new(2);
        buf.copy_from_slice(&[1.0, 2.0, 3.0]);
    }

    #[test]
    fn aligned_buffer_mut_slice() {
        let mut buf = AlignedBuffer::new(4);
        {
            let s = buf.as_mut_slice();
            s[0] = 42.0;
            s[3] = -1.0;
        }
        assert!((buf.as_slice()[0] - 42.0).abs() < f32::EPSILON);
        assert!((buf.as_slice()[3] - (-1.0)).abs() < f32::EPSILON);
    }

    #[test]
    fn aligned_blocks_new_and_access() {
        let blocks = AlignedBlocks::new(16);
        assert_eq!(blocks.len(), 16);
        assert!(!blocks.is_empty());
        assert_eq!(blocks.as_ptr() as usize % ALIGNMENT, 0);
    }

    #[test]
    fn aligned_blocks_zero_length() {
        let blocks = AlignedBlocks::new(0);
        assert_eq!(blocks.len(), 0);
        assert!(blocks.is_empty());
        assert_eq!(blocks.as_slice().len(), 0);
    }

    #[test]
    fn aligned_blocks_mut_slice_aliases_storage() {
        let mut blocks = AlignedBlocks::new(4);
        let base = blocks.as_ptr();
        let slice = blocks.as_mut_slice();
        assert_eq!(slice.len(), 4);
        assert_eq!(slice.as_ptr(), base);
    }

    #[test]
    fn align_to_cache_line_empty() {
        let data: &[f32] = &[];
        let (prefix, aligned, suffix) = align_to_cache_line(data);
        assert!(prefix.is_empty());
        assert!(aligned.is_empty());
        assert!(suffix.is_empty());
    }

    #[test]
    fn align_to_cache_line_already_aligned() {
        let buf = AlignedBuffer::new(64);
        let data = buf.as_slice();
        let (prefix, aligned, suffix) = align_to_cache_line(data);
        assert!(
            prefix.is_empty(),
            "prefix should be empty for aligned buffer"
        );
        assert_eq!(aligned.len() + suffix.len(), data.len());
    }

    #[test]
    fn align_to_cache_line_preserves_data() {
        let buf = AlignedBuffer::new(128);
        let data = buf.as_slice();
        let (prefix, aligned, suffix) = align_to_cache_line(data);
        assert_eq!(
            prefix.len() + aligned.len() + suffix.len(),
            data.len(),
            "split must preserve total length"
        );
    }

    #[test]
    fn align_to_cache_line_misaligned_start() {
        let buf = AlignedBuffer::new(64);
        // Start one f32 (4 bytes) past a 64-byte boundary: 15 floats reach
        // the next boundary, then two whole 16-float lines, then 12 left over.
        let data = &buf.as_slice()[1..60];
        let (prefix, aligned, suffix) = align_to_cache_line(data);
        assert_eq!(prefix.len(), 15);
        assert_eq!(aligned.len(), 32);
        assert_eq!(suffix.len(), 12);
        assert_eq!(aligned.as_ptr() as usize % ALIGNMENT, 0);
    }

    #[test]
    fn align_to_cache_line_shorter_than_prefix() {
        let buf = AlignedBuffer::new(64);
        // Slice ends before the next boundary, so everything is prefix.
        let data = &buf.as_slice()[3..7];
        let (prefix, aligned, suffix) = align_to_cache_line(data);
        assert_eq!(prefix.len(), data.len());
        assert!(aligned.is_empty());
        assert!(suffix.is_empty());
    }

    #[test]
    fn aligned_buffer_debug() {
        let buf = AlignedBuffer::new(32);
        let dbg = format!("{buf:?}");
        assert!(dbg.contains("AlignedBuffer"));
        assert!(dbg.contains("32"));
    }

    #[test]
    fn aligned_blocks_debug() {
        let blocks = AlignedBlocks::new(8);
        let dbg = format!("{blocks:?}");
        assert!(dbg.contains("AlignedBlocks"));
    }
}