oxiblas_core/memory/alloc.rs
//! Memory management utilities for OxiBLAS.
//!
//! This module provides:
//! - Aligned memory allocation
//! - Stack-based temporary allocation (StackReq pattern)
//! - Cache-aware data layout utilities
//! - Prefetch hints for cache optimization
//! - Memory pool for temporary allocations
//! - Custom allocator support via the `Alloc` trait

use core::alloc::Layout;
use core::mem::size_of;
use std::alloc::{alloc, alloc_zeroed, dealloc};

// =============================================================================
// Custom Allocator Trait
// =============================================================================

/// An allocator trait compatible with stable Rust.
///
/// This trait provides a simplified interface for custom memory allocators,
/// similar to `std::alloc::Allocator` but available on stable Rust.
///
/// # Safety
///
/// Implementations must ensure that:
/// - `allocate` returns a valid pointer or null on failure
/// - `deallocate` is called with the same layout used for allocation
/// - Memory is properly aligned as specified by the layout
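///
/// # Example
///
/// A minimal sketch of allocator-generic code. The crate path below is
/// assumed from this file's location, so the snippet is not compiled as a
/// doc test.
///
/// ```ignore
/// use core::alloc::Layout;
/// use oxiblas_core::memory::alloc::{Alloc, Global};
///
/// // Allocate a zeroed scratch buffer of `n` f64 values with any allocator.
/// fn with_scratch<A: Alloc>(a: &A, n: usize) {
///     let layout = Layout::array::<f64>(n).unwrap();
///     let ptr = a.allocate_zeroed(layout);
///     assert!(!ptr.is_null(), "allocation failed");
///     // ... use `ptr` as `n` zeroed f64 values ...
///     unsafe { a.deallocate(ptr, layout) };
/// }
///
/// with_scratch(&Global, 1024);
/// ```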
pub unsafe trait Alloc: Clone {
    /// Allocates memory with the specified layout.
    ///
    /// Returns a pointer to the allocated memory, or null on failure.
    fn allocate(&self, layout: Layout) -> *mut u8;

    /// Allocates zero-initialized memory with the specified layout.
    ///
    /// Returns a pointer to the allocated memory, or null on failure.
    fn allocate_zeroed(&self, layout: Layout) -> *mut u8;

    /// Deallocates memory previously allocated with `allocate`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that:
    /// - `ptr` was previously returned by `allocate` or `allocate_zeroed`
    /// - `layout` matches the layout used for allocation
    /// - The memory has not already been deallocated
    unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout);
}

/// The global allocator.
///
/// This is the default allocator used by `AlignedVec` and other types.
/// It delegates to the Rust global allocator (`std::alloc::alloc`).
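///
/// # Example
///
/// A small sketch of a raw allocate/deallocate round trip (crate path
/// assumed, so the snippet is not compiled as a doc test):
///
/// ```ignore
/// use core::alloc::Layout;
/// use oxiblas_core::memory::alloc::{Alloc, Global, DEFAULT_ALIGN};
///
/// // 256 bytes aligned to the default (cache-line) alignment.
/// let layout = Layout::from_size_align(256, DEFAULT_ALIGN).unwrap();
/// let ptr = Global.allocate(layout);
/// assert!(!ptr.is_null());
/// assert_eq!(ptr as usize % DEFAULT_ALIGN, 0);
/// unsafe { Global.deallocate(ptr, layout) };
/// ```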
#[derive(Debug, Clone, Copy, Default)]
pub struct Global;

// SAFETY: `Global` delegates to the Rust global allocator, which honors the
// requested layout and alignment, satisfying the `Alloc` contract.
unsafe impl Alloc for Global {
    #[inline]
    fn allocate(&self, layout: Layout) -> *mut u8 {
        if layout.size() == 0 {
            // Return a properly aligned dangling pointer for zero-sized allocations.
            return layout.align() as *mut u8;
        }
        unsafe { alloc(layout) }
    }

    #[inline]
    fn allocate_zeroed(&self, layout: Layout) -> *mut u8 {
        if layout.size() == 0 {
            return layout.align() as *mut u8;
        }
        unsafe { alloc_zeroed(layout) }
    }

    #[inline]
    unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout) {
        if layout.size() != 0 {
            dealloc(ptr, layout);
        }
    }
}

// =============================================================================
// Prefetch hints
// =============================================================================

/// Prefetch locality hint levels.
///
/// These correspond to the x86 prefetch instruction locality hints.
/// Higher locality means the data is expected to be accessed more times
/// and should be kept in closer cache levels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
    /// Non-temporal: Data will be accessed once and then not reused.
    /// Minimizes cache pollution.
    NonTemporal,
    /// Low locality: Data will be accessed a few times.
    /// Kept in L3 cache.
    Low,
    /// Medium locality: Data will be accessed several times.
    /// Kept in L2 cache.
    Medium,
    /// High locality: Data will be accessed many times.
    /// Kept in L1 cache.
    High,
}

/// Prefetches data for reading.
///
/// This is a hint to the processor that data at the given address will be
/// read in the near future. The processor may choose to load the data into
/// cache ahead of time.
///
/// # Arguments
/// * `ptr` - Pointer to the data to prefetch
/// * `locality` - Hint about how often the data will be reused
///
/// # Notes
/// The pointer does not need to be valid; prefetches of invalid addresses are
/// simply ignored by the hardware, although they may still cost some performance.
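///
/// # Example
///
/// A sketch of prefetching ahead of a streaming reduction (the crate path and
/// the look-ahead distance are illustrative assumptions):
///
/// ```ignore
/// use oxiblas_core::memory::alloc::{prefetch_read, PrefetchLocality};
///
/// let data = vec![1.0f64; 1 << 16];
/// let ahead = 64; // elements, i.e. 8 cache lines of f64 on 64-byte lines
/// let mut sum = 0.0;
/// for i in 0..data.len() {
///     if i + ahead < data.len() {
///         prefetch_read(&data[i + ahead] as *const f64, PrefetchLocality::Low);
///     }
///     sum += data[i];
/// }
/// assert_eq!(sum, (1 << 16) as f64);
/// ```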
#[inline]
pub fn prefetch_read<T>(ptr: *const T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::*;
        unsafe {
            // `_mm_prefetch` takes the locality hint as a const generic parameter.
            match locality {
                PrefetchLocality::NonTemporal => _mm_prefetch::<_MM_HINT_NTA>(ptr as *const i8),
                PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(ptr as *const i8),
                PrefetchLocality::Medium => _mm_prefetch::<_MM_HINT_T1>(ptr as *const i8),
                PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(ptr as *const i8),
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // AArch64 prefetch via the PRFM instruction with the
        // PLDL1KEEP / PLDL2KEEP / PLDL3KEEP load hints.
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
                    core::arch::asm!(
                        "prfm pldl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::Medium => {
                    core::arch::asm!(
                        "prfm pldl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::High => {
                    core::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // No prefetch support on this architecture - silently ignore.
        let _ = (ptr, locality);
    }
}

/// Prefetches data for writing.
///
/// This is a hint to the processor that data at the given address will be
/// written in the near future. This is useful for write-allocate cache policies.
///
/// # Arguments
/// * `ptr` - Pointer to the data to prefetch
/// * `locality` - Hint about how often the data will be reused
///
/// # Notes
/// The pointer does not need to be valid; prefetches of invalid addresses are
/// simply ignored by the hardware.
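///
/// # Example
///
/// A sketch of warming destination cache lines ahead of a fill loop (the
/// crate path and distance are illustrative assumptions):
///
/// ```ignore
/// use oxiblas_core::memory::alloc::{prefetch_write, PrefetchLocality};
///
/// let mut dst = vec![0.0f64; 4096];
/// let ahead = 64; // elements to prefetch ahead of the write position
/// for i in 0..dst.len() {
///     if i + ahead < dst.len() {
///         prefetch_write(&mut dst[i + ahead] as *mut f64, PrefetchLocality::Medium);
///     }
///     dst[i] = i as f64;
/// }
/// ```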
#[inline]
pub fn prefetch_write<T>(ptr: *mut T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::*;
        unsafe {
            // x86 has no separate prefetch-for-write in SSE/AVX; PREFETCHW
            // (3DNow!, later Broadwell+) and PREFETCHWT1 are optional
            // extensions, so fall back to the regular read prefetch.
            match locality {
                PrefetchLocality::NonTemporal => _mm_prefetch::<_MM_HINT_NTA>(ptr as *const i8),
                PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(ptr as *const i8),
                PrefetchLocality::Medium => _mm_prefetch::<_MM_HINT_T1>(ptr as *const i8),
                PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(ptr as *const i8),
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // AArch64 prefetch-for-store via PRFM with the
        // PSTL1KEEP / PSTL2KEEP / PSTL3KEEP hints.
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
                    core::arch::asm!(
                        "prfm pstl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::Medium => {
                    core::arch::asm!(
                        "prfm pstl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::High => {
                    core::arch::asm!(
                        "prfm pstl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        let _ = (ptr, locality);
    }
}

/// Prefetches a range of memory for reading.
///
/// This prefetches multiple cache lines starting at `ptr` and covering
/// `count` elements of type `T`.
///
/// # Arguments
/// * `ptr` - Pointer to the start of the data
/// * `count` - Number of elements to prefetch
/// * `locality` - Hint about data reuse
#[inline]
pub fn prefetch_read_range<T>(ptr: *const T, count: usize, locality: PrefetchLocality) {
    let bytes = count * size_of::<T>();
    let num_lines = bytes.div_ceil(CACHE_LINE_SIZE);

    for i in 0..num_lines {
        let offset = i * CACHE_LINE_SIZE;
        prefetch_read(unsafe { (ptr as *const u8).add(offset) }, locality);
    }
}

/// Prefetches a range of memory for writing.
#[inline]
pub fn prefetch_write_range<T>(ptr: *mut T, count: usize, locality: PrefetchLocality) {
    let bytes = count * size_of::<T>();
    let num_lines = bytes.div_ceil(CACHE_LINE_SIZE);

    for i in 0..num_lines {
        let offset = i * CACHE_LINE_SIZE;
        prefetch_write(unsafe { (ptr as *mut u8).add(offset) }, locality);
    }
}

/// Prefetch distance for streaming access patterns.
///
/// This expresses how far ahead of the current position to prefetch, measured
/// in cache lines. The distance should roughly cover the memory latency at the
/// expected streaming bandwidth; the default of 8 lines works well in practice.
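///
/// # Example
///
/// A sketch of combining the distance with `prefetch_read` in a streaming
/// loop (crate path assumed):
///
/// ```ignore
/// use oxiblas_core::memory::alloc::{prefetch_read, PrefetchDistance, PrefetchLocality};
///
/// let data = vec![1.0f32; 1 << 20];
/// let ahead = PrefetchDistance::default().offset_elements::<f32>();
/// let mut sum = 0.0f32;
/// for i in 0..data.len() {
///     if i + ahead < data.len() {
///         prefetch_read(&data[i + ahead] as *const f32, PrefetchLocality::Low);
///     }
///     sum += data[i];
/// }
/// ```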
#[derive(Debug, Clone, Copy)]
pub struct PrefetchDistance {
    /// Number of cache lines to prefetch ahead
    pub lines_ahead: usize,
}

impl Default for PrefetchDistance {
    fn default() -> Self {
        // Default: prefetch 8 cache lines ahead.
        // This works well for most streaming workloads.
        Self { lines_ahead: 8 }
    }
}

impl PrefetchDistance {
    /// Creates a new prefetch distance with the given look-ahead in cache lines.
    pub const fn new(lines_ahead: usize) -> Self {
        Self { lines_ahead }
    }

    /// Returns the prefetch offset in bytes.
    #[inline]
    pub const fn offset_bytes(&self) -> usize {
        self.lines_ahead * CACHE_LINE_SIZE
    }

    /// Returns the prefetch offset in elements of type `T`.
    #[inline]
    pub const fn offset_elements<T>(&self) -> usize {
        self.offset_bytes() / size_of::<T>()
    }
}

/// Cache line size for the target architecture.
///
/// Apple Silicon (M1/M2/M3) uses 128-byte cache lines.
#[cfg(target_arch = "aarch64")]
pub const CACHE_LINE_SIZE: usize = 128;

/// Cache line size for the target architecture.
///
/// x86_64 and most other architectures use 64-byte cache lines.
#[cfg(not(target_arch = "aarch64"))]
pub const CACHE_LINE_SIZE: usize = 64;

/// Default alignment for matrices.
///
/// This is set to match the cache line size for optimal memory access patterns.
/// - Apple Silicon (aarch64): 128 bytes
/// - x86_64 and others: 64 bytes
pub const DEFAULT_ALIGN: usize = CACHE_LINE_SIZE;

/// Computes the size in bytes of `count` elements of type `T`, rounded up to a
/// multiple of `align` (which must be a power of two).
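///
/// # Example
///
/// A worked sketch (crate path assumed): 10 `f64`s occupy 80 bytes, which
/// rounds up to 128 under 64-byte alignment.
///
/// ```ignore
/// use oxiblas_core::memory::alloc::aligned_size;
///
/// assert_eq!(aligned_size::<f64>(10, 64), 128);
/// assert_eq!(aligned_size::<f32>(16, 64), 64); // already a multiple of 64
/// ```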
#[inline]
pub const fn aligned_size<T>(count: usize, align: usize) -> usize {
    let size = count * size_of::<T>();
    (size + align - 1) & !(align - 1)
}

/// Computes the number of elements of type `T` that fit in `bytes` bytes.
#[inline]
pub const fn elements_per_aligned_bytes<T>(bytes: usize) -> usize {
    bytes / size_of::<T>()
}

/// Rounds `value` up to the next multiple of `align`, which must be a power of two.
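///
/// # Example
///
/// A small sketch (crate path assumed):
///
/// ```ignore
/// use oxiblas_core::memory::alloc::round_up_pow2;
///
/// assert_eq!(round_up_pow2(100, 64), 128);
/// assert_eq!(round_up_pow2(128, 64), 128);
/// ```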
#[inline]
pub const fn round_up_pow2(value: usize, align: usize) -> usize {
    debug_assert!(align.is_power_of_two());
    (value + align - 1) & !(align - 1)
}

/// Checks if a pointer is aligned to the given alignment.
#[inline]
pub fn is_aligned<T>(ptr: *const T, align: usize) -> bool {
    (ptr as usize) % align == 0
}