/// Temporal-locality hint that selects which prefetch instruction variant is issued.
///
/// The concrete instruction chosen for each variant is decided by
/// [`prefetch_read`] and [`prefetch_write`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PrefetchLocality {
    /// Mapped to `_MM_HINT_NTA` on x86_64; on aarch64 it shares the level-3
    /// `keep` prefetch with [`PrefetchLocality::Low`].
    NonTemporal,
    /// Mapped to `_MM_HINT_T2` on x86_64 and to the level-3 `keep` prefetch on aarch64.
    Low,
    /// Mapped to `_MM_HINT_T1` on x86_64 and to the level-2 `keep` prefetch on aarch64.
    Medium,
    /// Mapped to `_MM_HINT_T0` on x86_64 and to the level-1 `keep` prefetch on aarch64.
    High,
}
47
48#[inline]
59pub fn prefetch_read<T>(ptr: *const T, locality: PrefetchLocality) {
60 #[cfg(target_arch = "x86_64")]
61 {
62 use core::arch::x86_64::*;
63 unsafe {
64 match locality {
65 PrefetchLocality::NonTemporal => _mm_prefetch(ptr as *const i8, _MM_HINT_NTA),
66 PrefetchLocality::Low => _mm_prefetch(ptr as *const i8, _MM_HINT_T2),
67 PrefetchLocality::Medium => _mm_prefetch(ptr as *const i8, _MM_HINT_T1),
68 PrefetchLocality::High => _mm_prefetch(ptr as *const i8, _MM_HINT_T0),
69 }
70 }
71 }
72
73 #[cfg(target_arch = "aarch64")]
74 {
75 unsafe {
78 match locality {
79 PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
80 core::arch::asm!(
81 "prfm pldl3keep, [{0}]",
82 in(reg) ptr,
83 options(nostack, preserves_flags)
84 );
85 }
86 PrefetchLocality::Medium => {
87 core::arch::asm!(
88 "prfm pldl2keep, [{0}]",
89 in(reg) ptr,
90 options(nostack, preserves_flags)
91 );
92 }
93 PrefetchLocality::High => {
94 core::arch::asm!(
95 "prfm pldl1keep, [{0}]",
96 in(reg) ptr,
97 options(nostack, preserves_flags)
98 );
99 }
100 }
101 }
102 }
103
104 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
105 {
106 let _ = (ptr, locality);
108 }
109}
110
111#[inline]
116pub fn prefetch_write<T>(ptr: *mut T, locality: PrefetchLocality) {
117 #[cfg(target_arch = "x86_64")]
118 {
119 use core::arch::x86_64::*;
120 unsafe {
122 match locality {
123 PrefetchLocality::NonTemporal => _mm_prefetch(ptr as *const i8, _MM_HINT_NTA),
124 PrefetchLocality::Low => _mm_prefetch(ptr as *const i8, _MM_HINT_T2),
125 PrefetchLocality::Medium => _mm_prefetch(ptr as *const i8, _MM_HINT_T1),
126 PrefetchLocality::High => _mm_prefetch(ptr as *const i8, _MM_HINT_T0),
127 }
128 }
129 }
130
131 #[cfg(target_arch = "aarch64")]
132 {
133 unsafe {
135 match locality {
136 PrefetchLocality::NonTemporal | PrefetchLocality::Low => {
137 core::arch::asm!(
138 "prfm pstl3keep, [{0}]",
139 in(reg) ptr,
140 options(nostack, preserves_flags)
141 );
142 }
143 PrefetchLocality::Medium => {
144 core::arch::asm!(
145 "prfm pstl2keep, [{0}]",
146 in(reg) ptr,
147 options(nostack, preserves_flags)
148 );
149 }
150 PrefetchLocality::High => {
151 core::arch::asm!(
152 "prfm pstl1keep, [{0}]",
153 in(reg) ptr,
154 options(nostack, preserves_flags)
155 );
156 }
157 }
158 }
159 }
160
161 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
162 {
163 let _ = (ptr, locality);
164 }
165}
166
/// Cache line size in bytes used by the range and column prefetch helpers to
/// step from one prefetch address to the next.
pub const CACHE_LINE_SIZE: usize = 64;

/// Prefetch look-ahead distance, counted in cache lines.
pub const PREFETCH_DISTANCE_LINES: usize = 8;

/// Prefetch look-ahead distance in bytes
/// ([`PREFETCH_DISTANCE_LINES`] cache lines of [`CACHE_LINE_SIZE`] bytes each).
pub const PREFETCH_DISTANCE_BYTES: usize = CACHE_LINE_SIZE * PREFETCH_DISTANCE_LINES;
178
179#[inline]
184pub fn prefetch_range_read<T>(ptr: *const T, len: usize, locality: PrefetchLocality) {
185 if len == 0 {
186 return;
187 }
188
189 let elem_size = core::mem::size_of::<T>();
190 let byte_len = len * elem_size;
191 let num_lines = byte_len.div_ceil(CACHE_LINE_SIZE);
192
193 for i in 0..num_lines {
194 let offset = i * CACHE_LINE_SIZE;
195 let addr = unsafe { (ptr as *const u8).add(offset) };
196 prefetch_read(addr, locality);
197 }
198}
199
200#[inline]
202pub fn prefetch_range_write<T>(ptr: *mut T, len: usize, locality: PrefetchLocality) {
203 if len == 0 {
204 return;
205 }
206
207 let elem_size = core::mem::size_of::<T>();
208 let byte_len = len * elem_size;
209 let num_lines = byte_len.div_ceil(CACHE_LINE_SIZE);
210
211 for i in 0..num_lines {
212 let offset = i * CACHE_LINE_SIZE;
213 let addr = unsafe { (ptr as *mut u8).add(offset) };
214 prefetch_write(addr, locality);
215 }
216}
217
218#[inline]
223pub fn prefetch_column<T>(
224 ptr: *const T,
225 nrows: usize,
226 row_stride: usize,
227 locality: PrefetchLocality,
228) {
229 let elem_size = core::mem::size_of::<T>();
230
231 if row_stride == 1 || (row_stride * elem_size) <= CACHE_LINE_SIZE {
233 prefetch_range_read(ptr, nrows, locality);
234 } else {
235 let lines_per_column = (nrows * elem_size).div_ceil(CACHE_LINE_SIZE);
238 for i in 0..lines_per_column.min(nrows) {
239 let row = i * (CACHE_LINE_SIZE / elem_size).max(1);
240 if row < nrows {
241 let addr = unsafe { ptr.add(row * row_stride) };
242 prefetch_read(addr, locality);
243 }
244 }
245 }
246}
247
248#[inline]
253pub fn prefetch_block<T>(
254 ptr: *const T,
255 block_rows: usize,
256 block_cols: usize,
257 row_stride: usize,
258 locality: PrefetchLocality,
259) {
260 for j in 0..block_cols {
261 let col_ptr = unsafe { ptr.add(j * row_stride) };
262 prefetch_column(col_ptr, block_rows, 1, locality);
263 }
264}
265
266pub struct MatrixPrefetcher<T> {
271 ptr: *const T,
273 nrows: usize,
275 ncols: usize,
277 row_stride: usize,
279 current_col: usize,
281 distance: usize,
283 locality: PrefetchLocality,
285}
286
287impl<T> MatrixPrefetcher<T> {
288 #[inline]
298 pub fn new(
299 ptr: *const T,
300 nrows: usize,
301 ncols: usize,
302 row_stride: usize,
303 distance: usize,
304 locality: PrefetchLocality,
305 ) -> Self {
306 let prefetcher = MatrixPrefetcher {
307 ptr,
308 nrows,
309 ncols,
310 row_stride,
311 current_col: 0,
312 distance,
313 locality,
314 };
315
316 for j in 0..distance.min(ncols) {
318 let col_ptr = unsafe { ptr.add(j * row_stride) };
319 prefetch_column(col_ptr, nrows, 1, locality);
320 }
321
322 prefetcher
323 }
324
325 #[inline]
329 pub fn advance(&mut self) {
330 self.current_col += 1;
331
332 let prefetch_col = self.current_col + self.distance;
333 if prefetch_col < self.ncols {
334 let col_ptr = unsafe { self.ptr.add(prefetch_col * self.row_stride) };
335 prefetch_column(col_ptr, self.nrows, 1, self.locality);
336 }
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Derived equality distinguishes the variants.
    #[test]
    fn test_prefetch_locality() {
        assert_ne!(PrefetchLocality::High, PrefetchLocality::Low);
        assert_eq!(PrefetchLocality::Medium, PrefetchLocality::Medium);
    }

    /// Every locality can be issued as a read prefetch without crashing.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_prefetch_read_safety() {
        let data = [1.0f64; 1024];
        let base = data.as_ptr();

        let cases = [
            (0usize, PrefetchLocality::High),
            (100, PrefetchLocality::Medium),
            (500, PrefetchLocality::Low),
            (900, PrefetchLocality::NonTemporal),
        ];
        for (offset, locality) in cases {
            prefetch_read(base.wrapping_add(offset), locality);
        }
    }

    /// Write prefetches can be issued on a mutable buffer without crashing.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_prefetch_write_safety() {
        let mut data = [1.0f64; 1024];

        prefetch_write(data.as_mut_ptr(), PrefetchLocality::High);
        let shifted = data.as_mut_ptr().wrapping_add(100);
        prefetch_write(shifted, PrefetchLocality::Medium);
    }

    /// Range prefetch handles a full buffer, an empty range and a single element.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_prefetch_range() {
        let data = vec![1.0f64; 4096];
        let base = data.as_ptr();

        prefetch_range_read(base, data.len(), PrefetchLocality::Medium);
        prefetch_range_read(base, 0, PrefetchLocality::High);
        prefetch_range_read(base, 1, PrefetchLocality::Low);
    }

    /// Column prefetch runs for both a unit stride and a large stride.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_prefetch_column() {
        let data = vec![1.0f64; 1000];
        let base = data.as_ptr();

        // Unit stride.
        prefetch_column(base, 100, 1, PrefetchLocality::High);

        // Stride of 100 elements between rows.
        prefetch_column(base, 10, 100, PrefetchLocality::Medium);
    }

    /// Block prefetch runs over a 64x64 block with a stride of 100.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_prefetch_block() {
        let data = vec![1.0f64; 10000];

        prefetch_block(data.as_ptr(), 64, 64, 100, PrefetchLocality::High);
    }

    /// The prefetcher can be advanced across every column of a 100x100 matrix.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_matrix_prefetcher() {
        let data = vec![1.0f64; 10000];
        let (nrows, ncols, stride, distance) = (100, 100, 100, 8);

        let mut prefetcher = MatrixPrefetcher::new(
            data.as_ptr(),
            nrows,
            ncols,
            stride,
            distance,
            PrefetchLocality::Medium,
        );

        (0..100).for_each(|_| prefetcher.advance());
    }

    /// The byte distance is derived from the line distance and the line size.
    #[test]
    fn test_cache_constants() {
        assert_eq!(CACHE_LINE_SIZE, 64);
        const { assert!(PREFETCH_DISTANCE_LINES > 0) };
        assert_eq!(
            PREFETCH_DISTANCE_BYTES,
            PREFETCH_DISTANCE_LINES * CACHE_LINE_SIZE
        );
    }
}