oxiblas_matrix/
prefetch.rs

//! Prefetch utilities for improved cache performance.
//!
//! This module provides prefetch hints that can improve performance for
//! large matrix operations by bringing data into cache before it's needed.
//!
//! # Cache Hierarchy
//!
//! Modern CPUs have multiple cache levels:
//! - L1 (fastest, smallest, ~32KB per core)
//! - L2 (fast, medium, ~256KB-1MB per core)
//! - L3 (slower, shared, ~8-32MB)
//!
//! # Usage
//!
//! Prefetching is most effective when:
//! - Processing large matrices that don't fit in cache
//! - Access patterns are predictable (sequential or strided)
//! - There's enough distance between the prefetch and the use
//!
//! # Example
//!
//! ```ignore
//! use oxiblas_matrix::prefetch::{
//!     prefetch_read, PrefetchLocality, CACHE_LINE_SIZE, PREFETCH_DISTANCE_BYTES,
//! };
//!
//! // Prefetch data a fixed distance ahead of the current read position,
//! // stepping one cache line of f64 values at a time.
//! let step = CACHE_LINE_SIZE / core::mem::size_of::<f64>();
//! let distance = PREFETCH_DISTANCE_BYTES / core::mem::size_of::<f64>();
//! for i in (0..n.saturating_sub(distance)).step_by(step) {
//!     prefetch_read(&data[i + distance], PrefetchLocality::Medium);
//! }
//! ```

/// Cache locality hint for prefetch operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
    /// Non-temporal: Data will be used once and not reused.
    /// Prefetches to L1 but may be evicted quickly.
    NonTemporal,
    /// Low: Data will be used a few times.
    /// Typically prefetches to L3.
    Low,
    /// Medium: Data will be used moderately.
    /// Typically prefetches to L2.
    Medium,
    /// High: Data will be heavily reused.
    /// Prefetches to L1 for fastest access.
    High,
}

/// Prefetch data for reading.
///
/// Issues a prefetch hint to bring the cache line containing `ptr` into cache.
/// This is a hint and may be ignored by the CPU.
///
/// # Notes
///
/// The pointer does not need to be aligned to a cache line; the CPU prefetches
/// the entire cache line containing the address. A prefetch cannot fault, so
/// this function is safe to call, but prefetching addresses that are never read
/// only wastes memory bandwidth.
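///
/// # Example
///
/// A minimal sketch of prefetching a fixed distance ahead while scanning a
/// slice; `data` and `distance` are illustrative and not part of this crate.
///
/// ```ignore
/// let distance = PREFETCH_DISTANCE_BYTES / core::mem::size_of::<f64>();
/// let mut sum = 0.0;
/// for (i, x) in data.iter().enumerate() {
///     if i + distance < data.len() {
///         prefetch_read(&data[i + distance], PrefetchLocality::High);
///     }
///     sum += *x;
/// }
/// ```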
#[inline]
pub fn prefetch_read<T>(ptr: *const T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::*;
        // `_mm_prefetch` takes the hint as a const generic parameter.
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal => _mm_prefetch::<_MM_HINT_NTA>(ptr as *const i8),
                PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(ptr as *const i8),
                PrefetchLocality::Medium => _mm_prefetch::<_MM_HINT_T1>(ptr as *const i8),
                PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(ptr as *const i8),
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // AArch64 prefetch via inline assembly: the PRFM instruction with
        // PLDL1KEEP, PLDL2KEEP, or PLDL3KEEP selects the target cache level.
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
                    core::arch::asm!(
                        "prfm pldl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::Medium => {
                    core::arch::asm!(
                        "prfm pldl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::High => {
                    core::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // No prefetch instruction available on this target - this is a no-op.
        let _ = (ptr, locality);
    }
}

/// Prefetch data for writing.
///
/// Similar to `prefetch_read` but hints that the data will be written.
/// This can help avoid read-for-ownership overhead on some architectures.
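///
/// # Example
///
/// A sketch of prefetching an output buffer shortly before it is overwritten;
/// `out`, `distance`, and `compute` are illustrative names, not part of this crate.
///
/// ```ignore
/// let distance = PREFETCH_DISTANCE_BYTES / core::mem::size_of::<f64>();
/// for i in 0..out.len() {
///     if i + distance < out.len() {
///         prefetch_write(&mut out[i + distance], PrefetchLocality::Medium);
///     }
///     out[i] = compute(i);
/// }
/// ```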
#[inline]
pub fn prefetch_write<T>(ptr: *mut T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::*;
        // PREFETCHW is not available on all x86-64 CPUs, so fall back to the
        // ordinary read-prefetch hints.
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal => _mm_prefetch::<_MM_HINT_NTA>(ptr as *const i8),
                PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(ptr as *const i8),
                PrefetchLocality::Medium => _mm_prefetch::<_MM_HINT_T1>(ptr as *const i8),
                PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(ptr as *const i8),
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // AArch64 store prefetch: PRFM with PSTL1KEEP, PSTL2KEEP, or PSTL3KEEP.
        unsafe {
            match locality {
                PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
                    core::arch::asm!(
                        "prfm pstl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::Medium => {
                    core::arch::asm!(
                        "prfm pstl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
                PrefetchLocality::High => {
                    core::arch::asm!(
                        "prfm pstl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // No prefetch instruction available on this target - this is a no-op.
        let _ = (ptr, locality);
    }
}

/// Cache line size in bytes.
///
/// 64 bytes is typical for modern x86-64 and most AArch64 cores; a few designs
/// (e.g. Apple Silicon) use 128-byte lines.
pub const CACHE_LINE_SIZE: usize = 64;

/// Suggested prefetch distance in cache lines for sequential access.
///
/// This is the number of cache lines ahead to prefetch. The optimal value
/// depends on memory latency and processing speed.
pub const PREFETCH_DISTANCE_LINES: usize = 8;

/// Suggested prefetch distance in bytes for sequential access.
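///
/// # Example
///
/// Converting the byte distance into an element count for a typed loop
/// (a small illustrative sketch):
///
/// ```ignore
/// // 8 lines * 64 bytes = 512 bytes, i.e. 64 f64 elements ahead.
/// let distance_elems = PREFETCH_DISTANCE_BYTES / core::mem::size_of::<f64>();
/// ```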
pub const PREFETCH_DISTANCE_BYTES: usize = PREFETCH_DISTANCE_LINES * CACHE_LINE_SIZE;

/// Prefetch a range of memory for reading.
///
/// Prefetches cache lines covering the range `[ptr, ptr + len)`.
/// Useful for preparing a contiguous block of data.
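///
/// # Example
///
/// A sketch of warming up a panel of a matrix before copying (packing) it;
/// `panel` and `packed` are illustrative slices, not part of this crate.
///
/// ```ignore
/// prefetch_range_read(panel.as_ptr(), panel.len(), PrefetchLocality::Medium);
/// for (dst, src) in packed.iter_mut().zip(panel.iter()) {
///     *dst = *src;
/// }
/// ```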
#[inline]
pub fn prefetch_range_read<T>(ptr: *const T, len: usize, locality: PrefetchLocality) {
    if len == 0 {
        return;
    }

    let elem_size = core::mem::size_of::<T>();
    let byte_len = len * elem_size;
    let num_lines = byte_len.div_ceil(CACHE_LINE_SIZE);

    for i in 0..num_lines {
        let offset = i * CACHE_LINE_SIZE;
        let addr = unsafe { (ptr as *const u8).add(offset) };
        prefetch_read(addr, locality);
    }
}

/// Prefetch a range of memory for writing.
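///
/// # Example
///
/// A sketch of requesting an output tile before accumulating into it;
/// `c_tile` and `partial` are illustrative names, not part of this crate.
///
/// ```ignore
/// prefetch_range_write(c_tile.as_mut_ptr(), c_tile.len(), PrefetchLocality::High);
/// for (c, &val) in c_tile.iter_mut().zip(partial.iter()) {
///     *c += val;
/// }
/// ```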
#[inline]
pub fn prefetch_range_write<T>(ptr: *mut T, len: usize, locality: PrefetchLocality) {
    if len == 0 {
        return;
    }

    let elem_size = core::mem::size_of::<T>();
    let byte_len = len * elem_size;
    let num_lines = byte_len.div_ceil(CACHE_LINE_SIZE);

    for i in 0..num_lines {
        let offset = i * CACHE_LINE_SIZE;
        let addr = unsafe { (ptr as *mut u8).add(offset) };
        prefetch_write(addr, locality);
    }
}

/// Prefetch a column of a matrix for reading.
///
/// When the column is stored contiguously (`row_stride == 1`), this prefetches
/// a contiguous range. For strided access it issues prefetches at the given stride.
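///
/// # Example
///
/// A sketch of prefetching the next column of a column-major matrix while the
/// current one is processed; `a`, `lda`, `ncols`, `nrows`, and `process_column`
/// are illustrative names, not part of this crate.
///
/// ```ignore
/// for j in 0..ncols {
///     if j + 1 < ncols {
///         // Columns are contiguous (stride 1) and start `lda` elements apart.
///         prefetch_column(
///             unsafe { a.as_ptr().add((j + 1) * lda) },
///             nrows,
///             1,
///             PrefetchLocality::Medium,
///         );
///     }
///     process_column(&a[j * lda..j * lda + nrows]);
/// }
/// ```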
#[inline]
pub fn prefetch_column<T>(
    ptr: *const T,
    nrows: usize,
    row_stride: usize,
    locality: PrefetchLocality,
) {
    let elem_size = core::mem::size_of::<T>();

    // Contiguous column (row_stride == 1), or several rows share a cache line:
    // prefetch the column as a single contiguous range.
    if row_stride == 1 || (row_stride * elem_size) <= CACHE_LINE_SIZE {
        prefetch_range_read(ptr, nrows, locality);
    } else {
        // Strided access: issue a limited number of prefetches, sampling rows at
        // roughly cache-line-sized intervals, to avoid flooding the prefetch queue.
        let lines_per_column = (nrows * elem_size).div_ceil(CACHE_LINE_SIZE);
        for i in 0..lines_per_column.min(nrows) {
            let row = i * (CACHE_LINE_SIZE / elem_size).max(1);
            if row < nrows {
                let addr = unsafe { ptr.add(row * row_stride) };
                prefetch_read(addr, locality);
            }
        }
    }
}

/// Prefetch a block of a matrix for reading.
///
/// Prefetches a rectangular block starting at `ptr` with dimensions
/// `block_rows × block_cols`. `row_stride` is the offset in elements between
/// the starts of consecutive columns (the leading dimension); each column is
/// assumed to be contiguous.
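///
/// # Example
///
/// A sketch of warming up the next tile in a blocked, GEMM-style loop; `b`,
/// `ldb`, `n`, `KB`, `NB`, and `multiply_panel` are illustrative names, not
/// part of this crate.
///
/// ```ignore
/// for jb in (0..n).step_by(NB) {
///     if jb + NB < n {
///         // Prefetch the next KB x NB tile while the current one is used.
///         prefetch_block(
///             unsafe { b.as_ptr().add((jb + NB) * ldb) },
///             KB,
///             NB,
///             ldb,
///             PrefetchLocality::Low,
///         );
///     }
///     multiply_panel(jb);
/// }
/// ```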
#[inline]
pub fn prefetch_block<T>(
    ptr: *const T,
    block_rows: usize,
    block_cols: usize,
    row_stride: usize,
    locality: PrefetchLocality,
) {
    for j in 0..block_cols {
        let col_ptr = unsafe { ptr.add(j * row_stride) };
        prefetch_column(col_ptr, block_rows, 1, locality);
    }
}

/// Prefetch hint for matrix operations.
///
/// This struct provides a convenient interface for prefetching during
/// matrix operations with predictable access patterns.
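///
/// # Example
///
/// A sketch of walking a column-major matrix column by column; `a`, `nrows`,
/// `ncols`, `lda`, and `process_column` are illustrative names, not part of
/// this crate.
///
/// ```ignore
/// let mut pf = MatrixPrefetcher::new(a.as_ptr(), nrows, ncols, lda, 8, PrefetchLocality::Medium);
/// for j in 0..ncols {
///     process_column(&a[j * lda..j * lda + nrows]);
///     pf.advance(); // keep the window `distance` columns ahead
/// }
/// ```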
pub struct MatrixPrefetcher<T> {
    /// Base pointer.
    ptr: *const T,
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Offset in elements between the starts of consecutive columns
    /// (the leading dimension).
    row_stride: usize,
    /// Current prefetch column.
    current_col: usize,
    /// Prefetch distance in columns.
    distance: usize,
    /// Locality hint.
    locality: PrefetchLocality,
}

impl<T> MatrixPrefetcher<T> {
    /// Creates a new matrix prefetcher.
    ///
    /// # Parameters
    /// - `ptr`: Pointer to matrix data
    /// - `nrows`: Number of rows
    /// - `ncols`: Number of columns
    /// - `row_stride`: Offset in elements between the starts of consecutive
    ///   columns (the leading dimension)
    /// - `distance`: Number of columns to prefetch ahead
    /// - `locality`: Cache locality hint
    #[inline]
    pub fn new(
        ptr: *const T,
        nrows: usize,
        ncols: usize,
        row_stride: usize,
        distance: usize,
        locality: PrefetchLocality,
    ) -> Self {
        let prefetcher = MatrixPrefetcher {
            ptr,
            nrows,
            ncols,
            row_stride,
            current_col: 0,
            distance,
            locality,
        };

        // Prefetch the initial window of columns.
        for j in 0..distance.min(ncols) {
            let col_ptr = unsafe { ptr.add(j * row_stride) };
            prefetch_column(col_ptr, nrows, 1, locality);
        }

        prefetcher
    }

    /// Advance to the next column and prefetch ahead.
    ///
    /// Call this after processing each column to keep data prefetched.
    #[inline]
    pub fn advance(&mut self) {
        // Prefetch `distance` columns ahead of the column that was just
        // processed, then move the cursor forward; prefetching before the
        // increment ensures no column in the window is skipped.
        let prefetch_col = self.current_col + self.distance;
        if prefetch_col < self.ncols {
            let col_ptr = unsafe { self.ptr.add(prefetch_col * self.row_stride) };
            prefetch_column(col_ptr, self.nrows, 1, self.locality);
        }

        self.current_col += 1;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prefetch_locality() {
        assert_ne!(PrefetchLocality::High, PrefetchLocality::Low);
        assert_eq!(PrefetchLocality::Medium, PrefetchLocality::Medium);
    }

    // Prefetch tests use inline assembly / CPU intrinsics, which Miri doesn't support
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_prefetch_read_safety() {
        // Prefetching should not crash even with unusual inputs
        let data = [1.0f64; 1024];

        prefetch_read(data.as_ptr(), PrefetchLocality::High);
        prefetch_read(data.as_ptr().wrapping_add(100), PrefetchLocality::Medium);
        prefetch_read(data.as_ptr().wrapping_add(500), PrefetchLocality::Low);
        prefetch_read(
            data.as_ptr().wrapping_add(900),
            PrefetchLocality::NonTemporal,
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_prefetch_write_safety() {
        let mut data = [1.0f64; 1024];

        prefetch_write(data.as_mut_ptr(), PrefetchLocality::High);
        prefetch_write(
            data.as_mut_ptr().wrapping_add(100),
            PrefetchLocality::Medium,
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_prefetch_range() {
        let data = vec![1.0f64; 4096];

        // Should not crash
        prefetch_range_read(data.as_ptr(), data.len(), PrefetchLocality::Medium);
        prefetch_range_read(data.as_ptr(), 0, PrefetchLocality::High); // Empty range
        prefetch_range_read(data.as_ptr(), 1, PrefetchLocality::Low); // Single element
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_prefetch_column() {
        let data = vec![1.0f64; 1000];

        // Contiguous column
        prefetch_column(data.as_ptr(), 100, 1, PrefetchLocality::High);

        // Strided column (simulating row-major access)
        prefetch_column(data.as_ptr(), 10, 100, PrefetchLocality::Medium);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_prefetch_block() {
        let data = vec![1.0f64; 10000];

        // Prefetch a 64x64 block with stride 100
        prefetch_block(data.as_ptr(), 64, 64, 100, PrefetchLocality::High);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_matrix_prefetcher() {
        let data = vec![1.0f64; 10000];

        let mut prefetcher = MatrixPrefetcher::new(
            data.as_ptr(),
            100, // nrows
            100, // ncols
            100, // row_stride
            8,   // distance
            PrefetchLocality::Medium,
        );

        // Simulate processing columns
        for _ in 0..100 {
            prefetcher.advance();
        }
    }

    #[test]
    fn test_cache_constants() {
        assert_eq!(CACHE_LINE_SIZE, 64);
        const { assert!(PREFETCH_DISTANCE_LINES > 0) };
        assert_eq!(
            PREFETCH_DISTANCE_BYTES,
            PREFETCH_DISTANCE_LINES * CACHE_LINE_SIZE
        );
    }
}