Skip to main content

oxiblas_core/memory/
alloc.rs

1//! Memory management utilities for OxiBLAS.
2//!
3//! This module provides:
4//! - Aligned memory allocation
5//! - Stack-based temporary allocation (StackReq pattern)
6//! - Cache-aware data layout utilities
7//! - Prefetch hints for cache optimization
8//! - Memory pool for temporary allocations
9//! - Custom allocator support via the `Alloc` trait
10
11use core::alloc::Layout;
12use core::mem::size_of;
13
14#[cfg(not(feature = "std"))]
15use alloc::alloc::{alloc, alloc_zeroed, dealloc};
16#[cfg(feature = "std")]
17use std::alloc::{alloc, alloc_zeroed, dealloc};
18
19// =============================================================================
20// Custom Allocator Trait
21// =============================================================================
22
/// A stable-Rust compatible allocator trait.
///
/// This trait provides a simplified, raw-pointer interface for custom memory
/// allocators, similar in spirit to `std::alloc::Allocator` but available on
/// stable Rust. Allocation failure is reported by returning a null pointer.
///
/// # Safety
///
/// Implementors must guarantee that:
/// - a non-null pointer returned by [`allocate`](Alloc::allocate) or
///   [`allocate_zeroed`](Alloc::allocate_zeroed) is aligned to `layout.align()`
///   and valid for reads and writes of `layout.size()` bytes until it is passed
///   to [`deallocate`](Alloc::deallocate);
/// - [`allocate_zeroed`](Alloc::allocate_zeroed) returns memory whose bytes
///   are all zero;
/// - allocation failure is signalled by returning null.
///
/// Callers must in turn uphold the contract documented on
/// [`deallocate`](Alloc::deallocate).
pub unsafe trait Alloc: Clone {
    /// Allocates memory with the specified layout.
    ///
    /// Returns a pointer to the allocated memory, or null on failure.
    /// For zero-sized layouts an implementation may return a non-null,
    /// suitably aligned dangling pointer instead of touching any allocator.
    fn allocate(&self, layout: Layout) -> *mut u8;

    /// Allocates zero-initialized memory with the specified layout.
    ///
    /// Returns a pointer to the allocated memory, or null on failure.
    fn allocate_zeroed(&self, layout: Layout) -> *mut u8;

    /// Deallocates memory previously obtained from this allocator.
    ///
    /// # Safety
    ///
    /// The caller must ensure that:
    /// - `ptr` was previously returned by `allocate` or `allocate_zeroed`
    /// - `layout` matches the layout used for allocation
    /// - The memory has not already been deallocated
    unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout);
}
55
56/// The global allocator.
57///
58/// This is the default allocator used by `AlignedVec` and other types.
59/// It delegates to the Rust global allocator (`std::alloc::alloc`).
60#[derive(Debug, Clone, Copy, Default)]
61pub struct Global;
62
63// SAFETY: Global allocator delegates to the Rust global allocator
64// which is guaranteed to be safe.
65unsafe impl Alloc for Global {
66    #[inline]
67    fn allocate(&self, layout: Layout) -> *mut u8 {
68        if layout.size() == 0 {
69            // Return a properly aligned dangling pointer for ZSTs
70            return layout.align() as *mut u8;
71        }
72        unsafe { alloc(layout) }
73    }
74
75    #[inline]
76    fn allocate_zeroed(&self, layout: Layout) -> *mut u8 {
77        if layout.size() == 0 {
78            return layout.align() as *mut u8;
79        }
80        unsafe { alloc_zeroed(layout) }
81    }
82
83    #[inline]
84    unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout) {
85        if layout.size() != 0 {
86            dealloc(ptr, layout);
87        }
88    }
89}
90
91// =============================================================================
92// Prefetch hints
93// =============================================================================
94
/// Temporal-locality hint passed to the prefetch functions.
///
/// The variants mirror the locality hints of the x86 prefetch instructions.
/// A higher locality says the data is expected to be reused more often, so the
/// processor should try to keep it in a cache level closer to the core. The
/// cache level named for each variant is the intended target only; hardware is
/// free to treat any prefetch hint differently.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
    /// Non-temporal: the data is touched once and not reused afterwards, so
    /// the prefetch should pollute the caches as little as possible.
    NonTemporal,
    /// Low locality: a few reuses expected; target the L3 cache.
    Low,
    /// Medium locality: several reuses expected; target the L2 cache.
    Medium,
    /// High locality: many reuses expected; target the L1 cache.
    High,
}
115
116/// Prefetches data for reading.
117///
118/// This is a hint to the processor that data at the given address will be
119/// read in the near future. The processor may choose to load the data into
120/// cache ahead of time.
121///
122/// # Arguments
123/// * `ptr` - Pointer to the data to prefetch
124/// * `locality` - Hint about how often the data will be reused
125///
126/// # Safety
127/// The pointer doesn't need to be valid - invalid prefetches are simply ignored.
128/// However, prefetching to invalid addresses may have performance implications.
129#[inline]
130pub fn prefetch_read<T>(ptr: *const T, locality: PrefetchLocality) {
131    #[cfg(target_arch = "x86_64")]
132    {
133        use core::arch::x86_64::*;
134        unsafe {
135            match locality {
136                PrefetchLocality::NonTemporal => _mm_prefetch(ptr.cast(), _MM_HINT_NTA),
137                PrefetchLocality::Low => _mm_prefetch(ptr.cast(), _MM_HINT_T2),
138                PrefetchLocality::Medium => _mm_prefetch(ptr.cast(), _MM_HINT_T1),
139                PrefetchLocality::High => _mm_prefetch(ptr.cast(), _MM_HINT_T0),
140            }
141        }
142    }
143
144    #[cfg(target_arch = "aarch64")]
145    {
146        // ARM NEON prefetch using inline assembly
147        // PRFM instruction with PLDL1KEEP, PLDL2KEEP, PLDL3KEEP
148        unsafe {
149            match locality {
150                PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
151                    core::arch::asm!(
152                        "prfm pldl3keep, [{0}]",
153                        in(reg) ptr,
154                        options(nostack, preserves_flags)
155                    );
156                }
157                PrefetchLocality::Medium => {
158                    core::arch::asm!(
159                        "prfm pldl2keep, [{0}]",
160                        in(reg) ptr,
161                        options(nostack, preserves_flags)
162                    );
163                }
164                PrefetchLocality::High => {
165                    core::arch::asm!(
166                        "prfm pldl1keep, [{0}]",
167                        in(reg) ptr,
168                        options(nostack, preserves_flags)
169                    );
170                }
171            }
172        }
173    }
174
175    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
176    {
177        // No prefetch support - silently ignore
178        let _ = (ptr, locality);
179    }
180}
181
182/// Prefetches data for writing.
183///
184/// This is a hint to the processor that data at the given address will be
185/// written in the near future. This is useful for write-allocate cache policies.
186///
187/// # Arguments
188/// * `ptr` - Pointer to the data to prefetch
189/// * `locality` - Hint about how often the data will be reused
190///
191/// # Safety
192/// The pointer doesn't need to be valid - invalid prefetches are simply ignored.
193#[inline]
194pub fn prefetch_write<T>(ptr: *mut T, locality: PrefetchLocality) {
195    #[cfg(target_arch = "x86_64")]
196    {
197        use core::arch::x86_64::*;
198        unsafe {
199            // x86 doesn't have separate prefetch-for-write in SSE/AVX
200            // PREFETCHW is available with 3DNow! or PREFETCHWT1 with newer CPUs
201            // Fall back to regular prefetch
202            match locality {
203                PrefetchLocality::NonTemporal => _mm_prefetch(ptr.cast(), _MM_HINT_NTA),
204                PrefetchLocality::Low => _mm_prefetch(ptr.cast(), _MM_HINT_T2),
205                PrefetchLocality::Medium => _mm_prefetch(ptr.cast(), _MM_HINT_T1),
206                PrefetchLocality::High => _mm_prefetch(ptr.cast(), _MM_HINT_T0),
207            }
208        }
209    }
210
211    #[cfg(target_arch = "aarch64")]
212    {
213        // ARM NEON prefetch for store using PSTL1KEEP, PSTL2KEEP, PSTL3KEEP
214        unsafe {
215            match locality {
216                PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
217                    core::arch::asm!(
218                        "prfm pstl3keep, [{0}]",
219                        in(reg) ptr,
220                        options(nostack, preserves_flags)
221                    );
222                }
223                PrefetchLocality::Medium => {
224                    core::arch::asm!(
225                        "prfm pstl2keep, [{0}]",
226                        in(reg) ptr,
227                        options(nostack, preserves_flags)
228                    );
229                }
230                PrefetchLocality::High => {
231                    core::arch::asm!(
232                        "prfm pstl1keep, [{0}]",
233                        in(reg) ptr,
234                        options(nostack, preserves_flags)
235                    );
236                }
237            }
238        }
239    }
240
241    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
242    {
243        let _ = (ptr, locality);
244    }
245}
246
247/// Prefetches a range of memory for reading.
248///
249/// This prefetches multiple cache lines starting at `ptr` and covering
250/// `count` elements of type `T`.
251///
252/// # Arguments
253/// * `ptr` - Pointer to the start of the data
254/// * `count` - Number of elements to prefetch
255/// * `locality` - Hint about data reuse
256#[inline]
257pub fn prefetch_read_range<T>(ptr: *const T, count: usize, locality: PrefetchLocality) {
258    let bytes = count * size_of::<T>();
259    let num_lines = bytes.div_ceil(CACHE_LINE_SIZE);
260
261    for i in 0..num_lines {
262        let offset = i * CACHE_LINE_SIZE;
263        prefetch_read(unsafe { (ptr as *const u8).add(offset) }, locality);
264    }
265}
266
267/// Prefetches a range of memory for writing.
268#[inline]
269pub fn prefetch_write_range<T>(ptr: *mut T, count: usize, locality: PrefetchLocality) {
270    let bytes = count * size_of::<T>();
271    let num_lines = bytes.div_ceil(CACHE_LINE_SIZE);
272
273    for i in 0..num_lines {
274        let offset = i * CACHE_LINE_SIZE;
275        prefetch_write(unsafe { (ptr as *mut u8).add(offset) }, locality);
276    }
277}
278
279/// Prefetch distance calculator for streaming access patterns.
280///
281/// This calculates the optimal prefetch distance (in elements) based on
282/// memory bandwidth and latency estimates.
283#[derive(Debug, Clone, Copy)]
284pub struct PrefetchDistance {
285    /// Number of cache lines to prefetch ahead
286    pub lines_ahead: usize,
287}
288
289impl Default for PrefetchDistance {
290    fn default() -> Self {
291        // Default: prefetch 8 cache lines ahead
292        // This works well for most streaming workloads
293        Self { lines_ahead: 8 }
294    }
295}
296
297impl PrefetchDistance {
298    /// Creates a new prefetch distance calculator.
299    pub const fn new(lines_ahead: usize) -> Self {
300        Self { lines_ahead }
301    }
302
303    /// Calculates the prefetch offset in bytes.
304    #[inline]
305    pub const fn offset_bytes(&self) -> usize {
306        self.lines_ahead * CACHE_LINE_SIZE
307    }
308
309    /// Calculates the prefetch offset in elements.
310    #[inline]
311    pub const fn offset_elements<T>(&self) -> usize {
312        self.offset_bytes() / size_of::<T>()
313    }
314}
315
/// Cache line size assumed for the target architecture, in bytes.
///
/// aarch64 uses 128 bytes, matching Apple Silicon (M1/M2/M3). Because 128 is a
/// multiple of 64, data aligned to this value is also line-aligned on aarch64
/// cores with 64-byte lines.
#[cfg(target_arch = "aarch64")]
pub const CACHE_LINE_SIZE: usize = 128;

/// Cache line size assumed for the target architecture, in bytes.
///
/// x86_64 and most other architectures use 64-byte cache lines.
#[cfg(not(target_arch = "aarch64"))]
pub const CACHE_LINE_SIZE: usize = 64;

// Alignment helpers round with bit masks, which requires a power of two.
const _: () = assert!(CACHE_LINE_SIZE.is_power_of_two());

/// Default alignment for matrices and SIMD buffers, in bytes.
///
/// This is set to match the cache line size for optimal memory access patterns.
/// - aarch64 (tuned for Apple Silicon): 128 bytes
/// - x86_64 and others: 64 bytes
pub const DEFAULT_ALIGN: usize = CACHE_LINE_SIZE;
335
/// Computes the byte size of `count` elements of `T`, rounded up to a
/// multiple of `align`.
///
/// `align` must be a non-zero power of two (checked with `debug_assert!`).
///
/// # Panics
/// Panics if `count * size_of::<T>()` or the rounded-up size overflows
/// `usize`, in release builds too, instead of silently producing an
/// undersized result. In a const context this is a compile-time error.
#[inline]
pub const fn aligned_size<T>(count: usize, align: usize) -> usize {
    debug_assert!(align.is_power_of_two());
    let mask = align - 1;
    let bytes = match count.checked_mul(size_of::<T>()) {
        Some(bytes) => bytes,
        None => panic!("aligned_size: element count overflows usize"),
    };
    match bytes.checked_add(mask) {
        Some(padded) => padded & !mask,
        None => panic!("aligned_size: rounded size overflows usize"),
    }
}
342
/// Computes how many whole elements of type `T` fit in `bytes` bytes.
///
/// The result is rounded down. Despite the name, no alignment adjustment is
/// applied; `bytes` is used as given.
///
/// Zero-sized types occupy no space, so for them this returns `usize::MAX`
/// (the same convention `Vec` uses for the capacity of zero-sized elements)
/// instead of dividing by zero.
#[inline]
pub const fn elements_per_aligned_bytes<T>(bytes: usize) -> usize {
    match size_of::<T>() {
        0 => usize::MAX,
        elem_size => bytes / elem_size,
    }
}
348
/// Rounds `value` up to the next multiple of `align`.
///
/// `align` must be a non-zero power of two (checked with `debug_assert!`).
///
/// # Panics
/// Panics if the rounded result does not fit in `usize`, in release builds
/// too, rather than wrapping to a wrong value.
#[inline]
pub const fn round_up_pow2(value: usize, align: usize) -> usize {
    debug_assert!(align.is_power_of_two());
    let mask = align - 1;
    match value.checked_add(mask) {
        Some(bumped) => bumped & !mask,
        None => panic!("round_up_pow2: rounded value overflows usize"),
    }
}
355
/// Returns `true` when the address of `ptr` is a multiple of `align`.
///
/// Only the numeric address is inspected; the pointer is never dereferenced.
/// Any non-zero `align` is accepted (it need not be a power of two).
///
/// # Panics
/// Panics if `align` is zero.
#[inline]
pub fn is_aligned<T>(ptr: *const T, align: usize) -> bool {
    let address = ptr as usize;
    let remainder = address % align;
    remainder == 0
}