oxiblas_core/memory/alloc.rs
//! Memory management utilities for OxiBLAS.
//!
//! This module provides:
//! - Aligned memory allocation
//! - Stack-based temporary allocation (StackReq pattern)
//! - Cache-aware data layout utilities
//! - Prefetch hints for cache optimization
//! - Memory pool for temporary allocations
//! - Custom allocator support via the `Alloc` trait

use core::alloc::Layout;
use core::mem::size_of;

#[cfg(not(feature = "std"))]
use alloc::alloc::{alloc, alloc_zeroed, dealloc};
#[cfg(feature = "std")]
use std::alloc::{alloc, alloc_zeroed, dealloc};

// =============================================================================
// Custom Allocator Trait
// =============================================================================

/// A stable-Rust compatible allocator trait.
///
/// This trait provides a simplified interface for custom memory allocators,
/// similar to `std::alloc::Allocator` but available on stable Rust.
///
/// # Safety
///
/// Implementations must ensure that:
/// - `allocate` returns a valid pointer or null on failure
/// - `deallocate` is called with the same layout used for allocation
/// - Memory is properly aligned as specified by the layout
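///
/// # Example
///
/// A minimal sketch of a custom allocator (the `MinAlign` type below is
/// illustrative, not part of this crate) that forces a minimum alignment by
/// widening the layout before delegating to [`Global`]:
///
/// ```ignore
/// #[derive(Clone, Copy, Default)]
/// struct MinAlign;
///
/// unsafe impl Alloc for MinAlign {
///     fn allocate(&self, layout: Layout) -> *mut u8 {
///         // Round the alignment up to at least 64 bytes.
///         Global.allocate(layout.align_to(64).expect("invalid layout"))
///     }
///
///     fn allocate_zeroed(&self, layout: Layout) -> *mut u8 {
///         Global.allocate_zeroed(layout.align_to(64).expect("invalid layout"))
///     }
///
///     unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout) {
///         // Must mirror the layout adjustment made in `allocate`.
///         unsafe { Global.deallocate(ptr, layout.align_to(64).expect("invalid layout")) }
///     }
/// }
/// ```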
pub unsafe trait Alloc: Clone {
    /// Allocates memory with the specified layout.
    ///
    /// Returns a pointer to the allocated memory, or null on failure.
    fn allocate(&self, layout: Layout) -> *mut u8;

    /// Allocates zero-initialized memory with the specified layout.
    ///
    /// Returns a pointer to the allocated memory, or null on failure.
    fn allocate_zeroed(&self, layout: Layout) -> *mut u8;

    /// Deallocates memory previously allocated with `allocate`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that:
    /// - `ptr` was previously returned by `allocate` or `allocate_zeroed`
    /// - `layout` matches the layout used for allocation
    /// - The memory has not already been deallocated
    unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout);
}

/// The global allocator.
///
/// This is the default allocator used by `AlignedVec` and other types.
/// It delegates to the Rust global allocator (`alloc`, `alloc_zeroed`, `dealloc`).
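///
/// # Example
///
/// A sketch of a raw allocate/deallocate round trip (error handling reduced
/// to a null check):
///
/// ```ignore
/// let layout = Layout::from_size_align(1024, DEFAULT_ALIGN).unwrap();
/// let ptr = Global.allocate(layout);
/// assert!(!ptr.is_null());
/// // ... use the buffer ...
/// unsafe { Global.deallocate(ptr, layout) };
/// ```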
#[derive(Debug, Clone, Copy, Default)]
pub struct Global;

// SAFETY: delegates to the Rust global allocator, which upholds the
// `Alloc` contract (valid, properly aligned pointers or null on failure).
unsafe impl Alloc for Global {
    #[inline]
    fn allocate(&self, layout: Layout) -> *mut u8 {
        if layout.size() == 0 {
            // Return a properly aligned dangling pointer for zero-sized allocations
            return layout.align() as *mut u8;
        }
        unsafe { alloc(layout) }
    }

    #[inline]
    fn allocate_zeroed(&self, layout: Layout) -> *mut u8 {
        if layout.size() == 0 {
            return layout.align() as *mut u8;
        }
        unsafe { alloc_zeroed(layout) }
    }

    #[inline]
    unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout) {
        if layout.size() != 0 {
            unsafe { dealloc(ptr, layout) };
        }
    }
}

// =============================================================================
// Prefetch hints
// =============================================================================

/// Prefetch locality hint levels.
///
/// These correspond to the x86 prefetch instruction locality hints.
/// Higher locality means the data is expected to be accessed more times
/// and should be kept in closer cache levels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
    /// Non-temporal: Data will be accessed once and then not reused.
    /// Minimizes cache pollution.
    NonTemporal,
    /// Low locality: Data will be accessed a few times.
    /// Kept in L3 cache.
    Low,
    /// Medium locality: Data will be accessed several times.
    /// Kept in L2 cache.
    Medium,
    /// High locality: Data will be accessed many times.
    /// Kept in L1 cache.
    High,
}

/// Prefetches data for reading.
///
/// This is a hint to the processor that data at the given address will be
/// read in the near future. The processor may choose to load the data into
/// cache ahead of time.
///
/// # Arguments
/// * `ptr` - Pointer to the data to prefetch
/// * `locality` - Hint about how often the data will be reused
///
/// # Safety
/// The pointer doesn't need to be valid - invalid prefetches are simply ignored.
/// However, prefetching to invalid addresses may have performance implications.
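///
/// # Example
///
/// A sketch of a streaming loop (assuming `data: &[f64]`) that prefetches a
/// fixed distance ahead of the element currently being summed:
///
/// ```ignore
/// let mut sum = 0.0;
/// for i in 0..data.len() {
///     if i + 64 < data.len() {
///         prefetch_read(unsafe { data.as_ptr().add(i + 64) }, PrefetchLocality::High);
///     }
///     sum += data[i];
/// }
/// ```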
#[inline]
pub fn prefetch_read<T>(ptr: *const T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::*;
        let p = ptr as *const i8;
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal => _mm_prefetch::<_MM_HINT_NTA>(p),
                PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(p),
                PrefetchLocality::Medium => _mm_prefetch::<_MM_HINT_T1>(p),
                PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(p),
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // AArch64 prefetch using inline assembly:
        // PRFM instruction with PLDL1KEEP, PLDL2KEEP, or PLDL3KEEP
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
                    core::arch::asm!(
                        "prfm pldl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::Medium => {
                    core::arch::asm!(
                        "prfm pldl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::High => {
                    core::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // No prefetch support - silently ignore
        let _ = (ptr, locality);
    }
}

/// Prefetches data for writing.
///
/// This is a hint to the processor that data at the given address will be
/// written in the near future. This is useful for write-allocate cache policies.
///
/// # Arguments
/// * `ptr` - Pointer to the data to prefetch
/// * `locality` - Hint about how often the data will be reused
///
/// # Safety
/// The pointer doesn't need to be valid - invalid prefetches are simply ignored.
#[inline]
pub fn prefetch_write<T>(ptr: *mut T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::*;
        let p = ptr as *const i8;
        unsafe {
            // x86 has no separate prefetch-for-write in SSE/AVX.
            // PREFETCHW (3DNow!/newer CPUs) and PREFETCHWT1 are not universally
            // available, so fall back to the regular prefetch hints.
            match locality {
                PrefetchLocality::NonTemporal => _mm_prefetch::<_MM_HINT_NTA>(p),
                PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(p),
                PrefetchLocality::Medium => _mm_prefetch::<_MM_HINT_T1>(p),
                PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(p),
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // AArch64 prefetch for store using PRFM with PSTL1KEEP, PSTL2KEEP, or PSTL3KEEP
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
                    core::arch::asm!(
                        "prfm pstl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::Medium => {
                    core::arch::asm!(
                        "prfm pstl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::High => {
                    core::arch::asm!(
                        "prfm pstl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        let _ = (ptr, locality);
    }
}

/// Prefetches a range of memory for reading.
///
/// This prefetches multiple cache lines starting at `ptr` and covering
/// `count` elements of type `T`.
///
/// # Arguments
/// * `ptr` - Pointer to the start of the data
/// * `count` - Number of elements to prefetch
/// * `locality` - Hint about data reuse
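///
/// # Example
///
/// A sketch that warms up the next column of a column-major matrix (assuming
/// `col_ptr` points at the first of `m` contiguous `f64` values):
///
/// ```ignore
/// prefetch_read_range(col_ptr, m, PrefetchLocality::Medium);
/// ```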
#[inline]
pub fn prefetch_read_range<T>(ptr: *const T, count: usize, locality: PrefetchLocality) {
    let bytes = count * size_of::<T>();
    let num_lines = bytes.div_ceil(CACHE_LINE_SIZE);

    for i in 0..num_lines {
        let offset = i * CACHE_LINE_SIZE;
        prefetch_read(unsafe { (ptr as *const u8).add(offset) }, locality);
    }
}

/// Prefetches a range of memory for writing.
#[inline]
pub fn prefetch_write_range<T>(ptr: *mut T, count: usize, locality: PrefetchLocality) {
    let bytes = count * size_of::<T>();
    let num_lines = bytes.div_ceil(CACHE_LINE_SIZE);

    for i in 0..num_lines {
        let offset = i * CACHE_LINE_SIZE;
        prefetch_write(unsafe { (ptr as *mut u8).add(offset) }, locality);
    }
}

/// Prefetch distance for streaming access patterns.
///
/// This converts a prefetch distance, expressed as a number of cache lines
/// ahead of the current position, into byte and element offsets.
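///
/// # Example
///
/// A sketch of a streaming loop (assuming `x: &[f64]`) that prefetches a
/// fixed number of cache lines ahead of the element being processed:
///
/// ```ignore
/// let dist = PrefetchDistance::default();
/// let ahead = dist.offset_elements::<f64>();
/// for i in 0..x.len() {
///     if i + ahead < x.len() {
///         prefetch_read(unsafe { x.as_ptr().add(i + ahead) }, PrefetchLocality::Low);
///     }
///     // ... process x[i] ...
/// }
/// ```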
#[derive(Debug, Clone, Copy)]
pub struct PrefetchDistance {
    /// Number of cache lines to prefetch ahead
    pub lines_ahead: usize,
}

impl Default for PrefetchDistance {
    fn default() -> Self {
        // Default: prefetch 8 cache lines ahead
        // This works well for most streaming workloads
        Self { lines_ahead: 8 }
    }
}

impl PrefetchDistance {
    /// Creates a new prefetch distance calculator.
    pub const fn new(lines_ahead: usize) -> Self {
        Self { lines_ahead }
    }

    /// Calculates the prefetch offset in bytes.
    #[inline]
    pub const fn offset_bytes(&self) -> usize {
        self.lines_ahead * CACHE_LINE_SIZE
    }

    /// Calculates the prefetch offset in elements.
    #[inline]
    pub const fn offset_elements<T>(&self) -> usize {
        self.offset_bytes() / size_of::<T>()
    }
}

/// Cache line size for the target architecture.
///
/// Apple Silicon (M1/M2/M3) uses 128-byte cache lines, so the aarch64 build
/// assumes 128 bytes.
#[cfg(target_arch = "aarch64")]
pub const CACHE_LINE_SIZE: usize = 128;

/// Cache line size for the target architecture.
///
/// x86_64 and most other architectures use 64-byte cache lines.
#[cfg(not(target_arch = "aarch64"))]
pub const CACHE_LINE_SIZE: usize = 64;

/// Default alignment for matrices.
///
/// This is set to match the cache line size for optimal memory access patterns:
/// - Apple Silicon (aarch64): 128 bytes
/// - x86_64 and others: 64 bytes
pub const DEFAULT_ALIGN: usize = CACHE_LINE_SIZE;

/// Computes the size in bytes of `count` elements of type `T`, rounded up to a
/// multiple of `align` (which must be a power of two).
#[inline]
pub const fn aligned_size<T>(count: usize, align: usize) -> usize {
    let size = count * size_of::<T>();
    (size + align - 1) & !(align - 1)
}

/// Computes the number of elements of type `T` that fit in the given number of bytes.
#[inline]
pub const fn elements_per_aligned_bytes<T>(bytes: usize) -> usize {
    bytes / size_of::<T>()
}

/// Rounds up to the next multiple of a power of 2.
#[inline]
pub const fn round_up_pow2(value: usize, align: usize) -> usize {
    debug_assert!(align.is_power_of_two());
    (value + align - 1) & !(align - 1)
}

/// Checks if a pointer is aligned to the given alignment.
#[inline]
pub fn is_aligned<T>(ptr: *const T, align: usize) -> bool {
    (ptr as usize) % align == 0
}
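
// Basic sanity checks for the allocator and alignment helpers above. These
// assume a 64- or 128-byte `DEFAULT_ALIGN` and the standard global allocator.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_up_and_aligned_size() {
        assert_eq!(round_up_pow2(13, 8), 16);
        assert_eq!(round_up_pow2(16, 8), 16);
        // 10 f64 values = 80 bytes, rounded up to the next 64-byte boundary.
        assert_eq!(aligned_size::<f64>(10, 64), 128);
        assert_eq!(elements_per_aligned_bytes::<f64>(128), 16);
    }

    #[test]
    fn global_alloc_roundtrip() {
        let layout = Layout::from_size_align(256, DEFAULT_ALIGN).unwrap();
        let ptr = Global.allocate(layout);
        assert!(!ptr.is_null());
        assert!(is_aligned(ptr, DEFAULT_ALIGN));
        unsafe { Global.deallocate(ptr, layout) };
    }

    #[test]
    fn prefetch_is_only_a_hint() {
        let data = [1.0f32; 1024];
        // Prefetch must never fault, whatever the locality hint.
        prefetch_read_range(data.as_ptr(), data.len(), PrefetchLocality::Low);
        prefetch_read(data.as_ptr(), PrefetchLocality::High);
    }
}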