oxiblas_core/memory/alloc.rs

//! Memory management utilities for OxiBLAS.
//!
//! This module provides:
//! - Aligned memory allocation
//! - Stack-based temporary allocation (StackReq pattern)
//! - Cache-aware data layout utilities
//! - Prefetch hints for cache optimization
//! - Memory pool for temporary allocations
//! - Custom allocator support via the `Alloc` trait

use core::alloc::Layout;
use core::mem::size_of;
use std::alloc::{alloc, alloc_zeroed, dealloc};

// =============================================================================
// Custom Allocator Trait
// =============================================================================

/// A stable-Rust-compatible allocator trait.
///
/// This trait provides a simplified interface for custom memory allocators,
/// similar to `std::alloc::Allocator` but available on stable Rust.
///
/// # Safety
///
/// Implementations must ensure that:
/// - `allocate` returns a valid pointer or null on failure
/// - `deallocate` is called with the same layout used for allocation
/// - Memory is properly aligned as specified by the layout
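///
/// # Example
///
/// A minimal sketch of a pass-through allocator that delegates to [`Global`]
/// (illustrative only; a real implementation might add tracking or pooling):
///
/// ```ignore
/// #[derive(Clone)]
/// struct PassThrough;
///
/// // SAFETY: delegates directly to `Global`, which upholds the contract.
/// unsafe impl Alloc for PassThrough {
///     fn allocate(&self, layout: Layout) -> *mut u8 {
///         Global.allocate(layout)
///     }
///     fn allocate_zeroed(&self, layout: Layout) -> *mut u8 {
///         Global.allocate_zeroed(layout)
///     }
///     unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout) {
///         Global.deallocate(ptr, layout)
///     }
/// }
/// ```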
pub unsafe trait Alloc: Clone {
    /// Allocates memory with the specified layout.
    ///
    /// Returns a pointer to the allocated memory, or null on failure.
    fn allocate(&self, layout: Layout) -> *mut u8;

    /// Allocates zero-initialized memory with the specified layout.
    ///
    /// Returns a pointer to the allocated memory, or null on failure.
    fn allocate_zeroed(&self, layout: Layout) -> *mut u8;

    /// Deallocates memory previously allocated with `allocate`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that:
    /// - `ptr` was previously returned by `allocate` or `allocate_zeroed`
    /// - `layout` matches the layout used for allocation
    /// - The memory has not already been deallocated
    unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout);
}

/// The global allocator.
///
/// This is the default allocator used by `AlignedVec` and other types.
/// It delegates to the Rust global allocator (`std::alloc::alloc`).
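///
/// # Example
///
/// A minimal allocate/deallocate round trip (sketch; error handling elided):
///
/// ```ignore
/// let layout = Layout::from_size_align(256, 64).unwrap();
/// let ptr = Global.allocate(layout);
/// assert!(!ptr.is_null());
/// // SAFETY: `ptr` came from `allocate` with the same layout.
/// unsafe { Global.deallocate(ptr, layout) };
/// ```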
#[derive(Debug, Clone, Copy, Default)]
pub struct Global;

// SAFETY: `Global` delegates to the Rust global allocator,
// which upholds the `Alloc` contract.
unsafe impl Alloc for Global {
    #[inline]
    fn allocate(&self, layout: Layout) -> *mut u8 {
        if layout.size() == 0 {
            // Return a properly aligned dangling pointer for zero-sized layouts
            return layout.align() as *mut u8;
        }
        unsafe { alloc(layout) }
    }

    #[inline]
    fn allocate_zeroed(&self, layout: Layout) -> *mut u8 {
        if layout.size() == 0 {
            return layout.align() as *mut u8;
        }
        unsafe { alloc_zeroed(layout) }
    }

    #[inline]
    unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout) {
        if layout.size() != 0 {
            dealloc(ptr, layout);
        }
    }
}

// =============================================================================
// Prefetch hints
// =============================================================================

/// Prefetch locality hint levels.
///
/// These correspond to the x86 prefetch instruction locality hints.
/// Higher locality means the data is expected to be accessed more times
/// and should be kept in closer cache levels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
    /// Non-temporal: Data will be accessed once and then not reused.
    /// Minimizes cache pollution.
    NonTemporal,
    /// Low locality: Data will be accessed a few times.
    /// Kept in L3 cache.
    Low,
    /// Medium locality: Data will be accessed several times.
    /// Kept in L2 cache.
    Medium,
    /// High locality: Data will be accessed many times.
    /// Kept in L1 cache.
    High,
}

/// Prefetches data for reading.
///
/// This is a hint to the processor that data at the given address will be
/// read in the near future. The processor may choose to load the data into
/// cache ahead of time.
///
/// # Arguments
/// * `ptr` - Pointer to the data to prefetch
/// * `locality` - Hint about how often the data will be reused
///
/// # Notes
/// The pointer does not need to be valid: invalid prefetches are simply
/// ignored by the hardware, although prefetching unmapped addresses can
/// waste bandwidth.
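///
/// # Example
///
/// A sketch of prefetching ahead of a streaming loop (the distance is a
/// tuning parameter, not a prescription):
///
/// ```ignore
/// let data = vec![0.0f64; 1024];
/// let ahead = 16; // elements to prefetch ahead of the current index
/// for i in 0..data.len() {
///     if i + ahead < data.len() {
///         // SAFETY: the offset stays within the allocation.
///         prefetch_read(unsafe { data.as_ptr().add(i + ahead) }, PrefetchLocality::Low);
///     }
///     // ... process data[i] ...
/// }
/// ```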
#[inline]
pub fn prefetch_read<T>(ptr: *const T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::*;
        // `_mm_prefetch` takes its hint as a const generic parameter.
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal => _mm_prefetch::<_MM_HINT_NTA>(ptr.cast()),
                PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(ptr.cast()),
                PrefetchLocality::Medium => _mm_prefetch::<_MM_HINT_T1>(ptr.cast()),
                PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(ptr.cast()),
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // AArch64 prefetch via the PRFM instruction
        // (PLDL1KEEP / PLDL2KEEP / PLDL3KEEP load hints).
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
                    core::arch::asm!(
                        "prfm pldl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::Medium => {
                    core::arch::asm!(
                        "prfm pldl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::High => {
                    core::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // No prefetch support on this architecture - silently ignore.
        let _ = (ptr, locality);
    }
}

/// Prefetches data for writing.
///
/// This is a hint to the processor that data at the given address will be
/// written in the near future. This is useful for write-allocate cache policies.
///
/// # Arguments
/// * `ptr` - Pointer to the data to prefetch
/// * `locality` - Hint about how often the data will be reused
///
/// # Notes
/// The pointer does not need to be valid: invalid prefetches are simply ignored.
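///
/// # Example
///
/// A sketch of warming a destination buffer before writing to it
/// (illustrative only):
///
/// ```ignore
/// let mut dst = vec![0u8; 4096];
/// prefetch_write(dst.as_mut_ptr(), PrefetchLocality::High);
/// dst[..4].copy_from_slice(&[1, 2, 3, 4]);
/// ```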
#[inline]
pub fn prefetch_write<T>(ptr: *mut T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::*;
        // x86 has no separate prefetch-for-write hint in SSE/AVX:
        // PREFETCHW requires 3DNow!/PRFCHW and PREFETCHWT1 is rare,
        // so fall back to the regular read prefetch.
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal => _mm_prefetch::<_MM_HINT_NTA>(ptr.cast::<i8>()),
                PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(ptr.cast::<i8>()),
                PrefetchLocality::Medium => _mm_prefetch::<_MM_HINT_T1>(ptr.cast::<i8>()),
                PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(ptr.cast::<i8>()),
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // AArch64 prefetch-for-store via the PRFM instruction
        // (PSTL1KEEP / PSTL2KEEP / PSTL3KEEP store hints).
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
                    core::arch::asm!(
                        "prfm pstl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::Medium => {
                    core::arch::asm!(
                        "prfm pstl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::High => {
                    core::arch::asm!(
                        "prfm pstl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        let _ = (ptr, locality);
    }
}

/// Prefetches a range of memory for reading.
///
/// This prefetches multiple cache lines starting at `ptr` and covering
/// `count` elements of type `T`.
///
/// # Arguments
/// * `ptr` - Pointer to the start of the data
/// * `count` - Number of elements to prefetch
/// * `locality` - Hint about data reuse
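///
/// # Example
///
/// A sketch of prefetching the next panel of a matrix before using it
/// (sizes are illustrative):
///
/// ```ignore
/// let panel = vec![0.0f32; 256];
/// prefetch_read_range(panel.as_ptr(), panel.len(), PrefetchLocality::Medium);
/// // ... compute on the previous panel while this one streams in ...
/// ```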
#[inline]
pub fn prefetch_read_range<T>(ptr: *const T, count: usize, locality: PrefetchLocality) {
    let bytes = count * size_of::<T>();
    let num_lines = bytes.div_ceil(CACHE_LINE_SIZE);

    for i in 0..num_lines {
        let offset = i * CACHE_LINE_SIZE;
        prefetch_read(unsafe { ptr.cast::<u8>().add(offset) }, locality);
    }
}

/// Prefetches a range of memory for writing.
#[inline]
pub fn prefetch_write_range<T>(ptr: *mut T, count: usize, locality: PrefetchLocality) {
    let bytes = count * size_of::<T>();
    let num_lines = bytes.div_ceil(CACHE_LINE_SIZE);

    for i in 0..num_lines {
        let offset = i * CACHE_LINE_SIZE;
        prefetch_write(unsafe { ptr.cast::<u8>().add(offset) }, locality);
    }
}

/// Prefetch distance for streaming access patterns.
///
/// This stores how far ahead of the current position to prefetch, expressed
/// in cache lines. The right distance depends on memory latency relative to
/// per-element work; it is a tuning knob rather than a computed optimum.
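///
/// # Example
///
/// A sketch of converting a configured distance into offsets:
///
/// ```ignore
/// let dist = PrefetchDistance::new(4);
/// assert_eq!(dist.offset_bytes(), 4 * CACHE_LINE_SIZE);
/// let ahead = dist.offset_elements::<f32>(); // how many f32s to look ahead
/// ```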
#[derive(Debug, Clone, Copy)]
pub struct PrefetchDistance {
    /// Number of cache lines to prefetch ahead
    pub lines_ahead: usize,
}

impl Default for PrefetchDistance {
    fn default() -> Self {
        // Default: prefetch 8 cache lines ahead.
        // This works well for most streaming workloads.
        Self { lines_ahead: 8 }
    }
}

impl PrefetchDistance {
    /// Creates a new prefetch distance with the given look-ahead in cache lines.
    pub const fn new(lines_ahead: usize) -> Self {
        Self { lines_ahead }
    }

    /// Returns the prefetch offset in bytes.
    #[inline]
    pub const fn offset_bytes(&self) -> usize {
        self.lines_ahead * CACHE_LINE_SIZE
    }

    /// Returns the prefetch offset in elements of type `T`.
    #[inline]
    pub const fn offset_elements<T>(&self) -> usize {
        self.offset_bytes() / size_of::<T>()
    }
}

/// Cache line size for the target architecture, used as the default SIMD
/// alignment.
///
/// Apple Silicon (M1/M2/M3) uses 128-byte cache lines.
#[cfg(target_arch = "aarch64")]
pub const CACHE_LINE_SIZE: usize = 128;

/// Cache line size for the target architecture, used as the default SIMD
/// alignment.
///
/// x86_64 and most other architectures use 64-byte cache lines.
#[cfg(not(target_arch = "aarch64"))]
pub const CACHE_LINE_SIZE: usize = 64;

/// Default alignment for matrices.
///
/// This is set to match the cache line size for optimal memory access patterns:
/// - Apple Silicon (aarch64): 128 bytes
/// - x86_64 and others: 64 bytes
pub const DEFAULT_ALIGN: usize = CACHE_LINE_SIZE;

/// Computes the aligned size in bytes for `count` elements of type `T`,
/// rounded up to a multiple of `align` (which must be a power of two).
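///
/// # Example
///
/// ```ignore
/// // 10 f64 values occupy 80 bytes; rounded up to a 64-byte boundary: 128.
/// assert_eq!(aligned_size::<f64>(10, 64), 128);
/// ```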
#[inline]
pub const fn aligned_size<T>(count: usize, align: usize) -> usize {
    debug_assert!(align.is_power_of_two());
    let size = count * size_of::<T>();
    (size + align - 1) & !(align - 1)
}

/// Computes the number of elements of type `T` that fit in `bytes` bytes.
#[inline]
pub const fn elements_per_aligned_bytes<T>(bytes: usize) -> usize {
    bytes / size_of::<T>()
}

/// Rounds up `value` to the next multiple of `align`, which must be a power of two.
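///
/// # Example
///
/// ```ignore
/// assert_eq!(round_up_pow2(100, 64), 128);
/// assert_eq!(round_up_pow2(128, 64), 128);
/// ```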
#[inline]
pub const fn round_up_pow2(value: usize, align: usize) -> usize {
    debug_assert!(align.is_power_of_two());
    (value + align - 1) & !(align - 1)
}

/// Checks if a pointer is aligned to the given alignment.
#[inline]
pub fn is_aligned<T>(ptr: *const T, align: usize) -> bool {
    (ptr as usize) % align == 0
}