1use crate::wide_utils::{FromBitmask, WideUtilsExt};
33#[cfg(target_arch = "aarch64")]
34use std::arch::asm;
35
/// Marker trait for the cache level a prefetch should target.
///
/// `HINT` is an architecture-specific selector in the range 0..=3 consumed
/// by the `match L::HINT` dispatch in the prefetch functions below: on
/// x86_64 it selects one of the `_MM_HINT_*` strategies, on aarch64 it
/// selects the target of a `prfm` instruction.
pub trait CacheLevel {
    /// Architecture-specific prefetch hint selector (0..=3).
    const HINT: i32;
}
41
42pub struct NTA;
44impl CacheLevel for NTA {
45 #[cfg(target_arch = "x86_64")]
46 const HINT: i32 = 0; #[cfg(target_arch = "aarch64")]
49 const HINT: i32 = 0; }
51
52pub struct L1;
54impl CacheLevel for L1 {
55 #[cfg(target_arch = "x86_64")]
56 const HINT: i32 = 3; #[cfg(target_arch = "aarch64")]
59 const HINT: i32 = 1; }
61
62pub struct L2;
64impl CacheLevel for L2 {
65 #[cfg(target_arch = "x86_64")]
66 const HINT: i32 = 2; #[cfg(target_arch = "aarch64")]
69 const HINT: i32 = 2; }
71
72pub struct L3;
74impl CacheLevel for L3 {
75 #[cfg(target_arch = "x86_64")]
76 const HINT: i32 = 1; #[cfg(target_arch = "aarch64")]
79 const HINT: i32 = 3; }
81
82#[inline(always)]
84pub fn prefetch_address<T, L: CacheLevel>(base: &T, offset: u32) {
85 let ptr = unsafe { (base as *const T).add(offset as usize) as *const i8 };
86
87 #[cfg(target_arch = "x86_64")]
88 {
89 use std::arch::x86_64::*;
90 unsafe {
91 match L::HINT {
92 0 => _mm_prefetch(ptr, _MM_HINT_NTA),
93 1 => _mm_prefetch(ptr, _MM_HINT_T2),
94 2 => _mm_prefetch(ptr, _MM_HINT_T1),
95 3 => _mm_prefetch(ptr, _MM_HINT_T0),
96 _ => _mm_prefetch(ptr, _MM_HINT_T0),
97 }
98 }
99 }
100
101 #[cfg(target_arch = "aarch64")]
102 {
103 unsafe {
104 match L::HINT {
105 0 => asm!("prfm pldl1strm, [{0}]", in(reg) ptr), 1 => asm!("prfm pldl1keep, [{0}]", in(reg) ptr), 2 => asm!("prfm pldl2keep, [{0}]", in(reg) ptr), 3 => asm!("prfm pldl3keep, [{0}]", in(reg) ptr), _ => asm!("prfm pldl1keep, [{0}]", in(reg) ptr), }
111 }
112 }
113
114 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
115 {
116 let _ = ptr;
118 }
119}
120
121#[inline(always)]
135pub fn prefetch_eight_offsets<T, L: CacheLevel>(base: &T, offsets: &[u32; 8]) {
136 let base_ptr = base as *const T;
137
138 #[cfg(target_arch = "x86_64")]
139 {
140 use std::arch::x86_64::*;
141 unsafe {
142 let ptrs = [
144 (base_ptr.add(offsets[0] as usize) as *const i8),
145 (base_ptr.add(offsets[1] as usize) as *const i8),
146 (base_ptr.add(offsets[2] as usize) as *const i8),
147 (base_ptr.add(offsets[3] as usize) as *const i8),
148 (base_ptr.add(offsets[4] as usize) as *const i8),
149 (base_ptr.add(offsets[5] as usize) as *const i8),
150 (base_ptr.add(offsets[6] as usize) as *const i8),
151 (base_ptr.add(offsets[7] as usize) as *const i8),
152 ];
153
154 match L::HINT {
156 0 => {
157 _mm_prefetch(ptrs[0], _MM_HINT_NTA);
158 _mm_prefetch(ptrs[1], _MM_HINT_NTA);
159 _mm_prefetch(ptrs[2], _MM_HINT_NTA);
160 _mm_prefetch(ptrs[3], _MM_HINT_NTA);
161 _mm_prefetch(ptrs[4], _MM_HINT_NTA);
162 _mm_prefetch(ptrs[5], _MM_HINT_NTA);
163 _mm_prefetch(ptrs[6], _MM_HINT_NTA);
164 _mm_prefetch(ptrs[7], _MM_HINT_NTA);
165 }
166 1 => {
167 _mm_prefetch(ptrs[0], _MM_HINT_T2);
168 _mm_prefetch(ptrs[1], _MM_HINT_T2);
169 _mm_prefetch(ptrs[2], _MM_HINT_T2);
170 _mm_prefetch(ptrs[3], _MM_HINT_T2);
171 _mm_prefetch(ptrs[4], _MM_HINT_T2);
172 _mm_prefetch(ptrs[5], _MM_HINT_T2);
173 _mm_prefetch(ptrs[6], _MM_HINT_T2);
174 _mm_prefetch(ptrs[7], _MM_HINT_T2);
175 }
176 2 => {
177 _mm_prefetch(ptrs[0], _MM_HINT_T1);
178 _mm_prefetch(ptrs[1], _MM_HINT_T1);
179 _mm_prefetch(ptrs[2], _MM_HINT_T1);
180 _mm_prefetch(ptrs[3], _MM_HINT_T1);
181 _mm_prefetch(ptrs[4], _MM_HINT_T1);
182 _mm_prefetch(ptrs[5], _MM_HINT_T1);
183 _mm_prefetch(ptrs[6], _MM_HINT_T1);
184 _mm_prefetch(ptrs[7], _MM_HINT_T1);
185 }
186 3 => {
187 _mm_prefetch(ptrs[0], _MM_HINT_T0);
188 _mm_prefetch(ptrs[1], _MM_HINT_T0);
189 _mm_prefetch(ptrs[2], _MM_HINT_T0);
190 _mm_prefetch(ptrs[3], _MM_HINT_T0);
191 _mm_prefetch(ptrs[4], _MM_HINT_T0);
192 _mm_prefetch(ptrs[5], _MM_HINT_T0);
193 _mm_prefetch(ptrs[6], _MM_HINT_T0);
194 _mm_prefetch(ptrs[7], _MM_HINT_T0);
195 }
196 _ => {
197 _mm_prefetch(ptrs[0], _MM_HINT_T0);
198 _mm_prefetch(ptrs[1], _MM_HINT_T0);
199 _mm_prefetch(ptrs[2], _MM_HINT_T0);
200 _mm_prefetch(ptrs[3], _MM_HINT_T0);
201 _mm_prefetch(ptrs[4], _MM_HINT_T0);
202 _mm_prefetch(ptrs[5], _MM_HINT_T0);
203 _mm_prefetch(ptrs[6], _MM_HINT_T0);
204 _mm_prefetch(ptrs[7], _MM_HINT_T0);
205 }
206 }
207 }
208 }
209
210 #[cfg(target_arch = "aarch64")]
211 {
212 unsafe {
213 let addrs = [
215 base_ptr.add(offsets[0] as usize) as *const u8,
216 base_ptr.add(offsets[1] as usize) as *const u8,
217 base_ptr.add(offsets[2] as usize) as *const u8,
218 base_ptr.add(offsets[3] as usize) as *const u8,
219 base_ptr.add(offsets[4] as usize) as *const u8,
220 base_ptr.add(offsets[5] as usize) as *const u8,
221 base_ptr.add(offsets[6] as usize) as *const u8,
222 base_ptr.add(offsets[7] as usize) as *const u8,
223 ];
224
225 match L::HINT {
227 0 => {
228 asm!("prfm pldl1strm, [{0}]", in(reg) addrs[0]);
230 asm!("prfm pldl1strm, [{0}]", in(reg) addrs[1]);
231 asm!("prfm pldl1strm, [{0}]", in(reg) addrs[2]);
232 asm!("prfm pldl1strm, [{0}]", in(reg) addrs[3]);
233 asm!("prfm pldl1strm, [{0}]", in(reg) addrs[4]);
234 asm!("prfm pldl1strm, [{0}]", in(reg) addrs[5]);
235 asm!("prfm pldl1strm, [{0}]", in(reg) addrs[6]);
236 asm!("prfm pldl1strm, [{0}]", in(reg) addrs[7]);
237 }
238 1 => {
239 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[0]);
241 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[1]);
242 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[2]);
243 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[3]);
244 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[4]);
245 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[5]);
246 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[6]);
247 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[7]);
248 }
249 2 => {
250 asm!("prfm pldl2keep, [{0}]", in(reg) addrs[0]);
252 asm!("prfm pldl2keep, [{0}]", in(reg) addrs[1]);
253 asm!("prfm pldl2keep, [{0}]", in(reg) addrs[2]);
254 asm!("prfm pldl2keep, [{0}]", in(reg) addrs[3]);
255 asm!("prfm pldl2keep, [{0}]", in(reg) addrs[4]);
256 asm!("prfm pldl2keep, [{0}]", in(reg) addrs[5]);
257 asm!("prfm pldl2keep, [{0}]", in(reg) addrs[6]);
258 asm!("prfm pldl2keep, [{0}]", in(reg) addrs[7]);
259 }
260 3 => {
261 asm!("prfm pldl3keep, [{0}]", in(reg) addrs[0]);
263 asm!("prfm pldl3keep, [{0}]", in(reg) addrs[1]);
264 asm!("prfm pldl3keep, [{0}]", in(reg) addrs[2]);
265 asm!("prfm pldl3keep, [{0}]", in(reg) addrs[3]);
266 asm!("prfm pldl3keep, [{0}]", in(reg) addrs[4]);
267 asm!("prfm pldl3keep, [{0}]", in(reg) addrs[5]);
268 asm!("prfm pldl3keep, [{0}]", in(reg) addrs[6]);
269 asm!("prfm pldl3keep, [{0}]", in(reg) addrs[7]);
270 }
271 _ => {
272 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[0]);
274 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[1]);
275 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[2]);
276 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[3]);
277 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[4]);
278 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[5]);
279 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[6]);
280 asm!("prfm pldl1keep, [{0}]", in(reg) addrs[7]);
281 }
282 }
283 }
284 }
285
286 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
287 {
288 let _ = (base_ptr, offsets);
290 }
291}
292
293#[inline(always)]
318pub fn prefetch_eight_masked<T, L: CacheLevel>(base: &T, offsets: [u32; 8], mask: u8) {
319 let base_ptr = base as *const T;
320
321 let base_addr = base_ptr as u64;
323 let base_simd = wide::u64x8::splat(base_addr);
324
325 let offsets_u32_simd = wide::u32x8::from(offsets) * (std::mem::size_of::<T>() as u32);
327 let zero_offsets_simd = wide::u32x8::splat(0);
328 let mask_simd = wide::u32x8::from_bitmask(mask);
329 let blended_offsets_simd = mask_simd.blend(offsets_u32_simd, zero_offsets_simd);
331
332 let offsets_u64_simd = blended_offsets_simd.widen_to_u64x8();
334 let selected_addrs_simd = base_simd + offsets_u64_simd;
335
336 let selected_addrs = selected_addrs_simd.to_array();
337
338 #[cfg(target_arch = "x86_64")]
339 {
340 use std::arch::x86_64::*;
341 unsafe {
342 let ptrs = [
344 selected_addrs[0] as *const i8,
345 selected_addrs[1] as *const i8,
346 selected_addrs[2] as *const i8,
347 selected_addrs[3] as *const i8,
348 selected_addrs[4] as *const i8,
349 selected_addrs[5] as *const i8,
350 selected_addrs[6] as *const i8,
351 selected_addrs[7] as *const i8,
352 ];
353
354 match L::HINT {
356 0 => {
357 _mm_prefetch(ptrs[0], _MM_HINT_NTA);
358 _mm_prefetch(ptrs[1], _MM_HINT_NTA);
359 _mm_prefetch(ptrs[2], _MM_HINT_NTA);
360 _mm_prefetch(ptrs[3], _MM_HINT_NTA);
361 _mm_prefetch(ptrs[4], _MM_HINT_NTA);
362 _mm_prefetch(ptrs[5], _MM_HINT_NTA);
363 _mm_prefetch(ptrs[6], _MM_HINT_NTA);
364 _mm_prefetch(ptrs[7], _MM_HINT_NTA);
365 }
366 1 => {
367 _mm_prefetch(ptrs[0], _MM_HINT_T2);
368 _mm_prefetch(ptrs[1], _MM_HINT_T2);
369 _mm_prefetch(ptrs[2], _MM_HINT_T2);
370 _mm_prefetch(ptrs[3], _MM_HINT_T2);
371 _mm_prefetch(ptrs[4], _MM_HINT_T2);
372 _mm_prefetch(ptrs[5], _MM_HINT_T2);
373 _mm_prefetch(ptrs[6], _MM_HINT_T2);
374 _mm_prefetch(ptrs[7], _MM_HINT_T2);
375 }
376 2 => {
377 _mm_prefetch(ptrs[0], _MM_HINT_T1);
378 _mm_prefetch(ptrs[1], _MM_HINT_T1);
379 _mm_prefetch(ptrs[2], _MM_HINT_T1);
380 _mm_prefetch(ptrs[3], _MM_HINT_T1);
381 _mm_prefetch(ptrs[4], _MM_HINT_T1);
382 _mm_prefetch(ptrs[5], _MM_HINT_T1);
383 _mm_prefetch(ptrs[6], _MM_HINT_T1);
384 _mm_prefetch(ptrs[7], _MM_HINT_T1);
385 }
386 3 => {
387 _mm_prefetch(ptrs[0], _MM_HINT_T0);
388 _mm_prefetch(ptrs[1], _MM_HINT_T0);
389 _mm_prefetch(ptrs[2], _MM_HINT_T0);
390 _mm_prefetch(ptrs[3], _MM_HINT_T0);
391 _mm_prefetch(ptrs[4], _MM_HINT_T0);
392 _mm_prefetch(ptrs[5], _MM_HINT_T0);
393 _mm_prefetch(ptrs[6], _MM_HINT_T0);
394 _mm_prefetch(ptrs[7], _MM_HINT_T0);
395 }
396 _ => {
397 _mm_prefetch(ptrs[0], _MM_HINT_T0);
398 _mm_prefetch(ptrs[1], _MM_HINT_T0);
399 _mm_prefetch(ptrs[2], _MM_HINT_T0);
400 _mm_prefetch(ptrs[3], _MM_HINT_T0);
401 _mm_prefetch(ptrs[4], _MM_HINT_T0);
402 _mm_prefetch(ptrs[5], _MM_HINT_T0);
403 _mm_prefetch(ptrs[6], _MM_HINT_T0);
404 _mm_prefetch(ptrs[7], _MM_HINT_T0);
405 }
406 }
407 }
408 }
409
410 #[cfg(target_arch = "aarch64")]
411 {
412 unsafe {
413 let ptrs = [
415 selected_addrs[0] as *const u8,
416 selected_addrs[1] as *const u8,
417 selected_addrs[2] as *const u8,
418 selected_addrs[3] as *const u8,
419 selected_addrs[4] as *const u8,
420 selected_addrs[5] as *const u8,
421 selected_addrs[6] as *const u8,
422 selected_addrs[7] as *const u8,
423 ];
424
425 match L::HINT {
427 0 => {
428 asm!("prfm pldl1strm, [{0}]", in(reg) ptrs[0]);
430 asm!("prfm pldl1strm, [{0}]", in(reg) ptrs[1]);
431 asm!("prfm pldl1strm, [{0}]", in(reg) ptrs[2]);
432 asm!("prfm pldl1strm, [{0}]", in(reg) ptrs[3]);
433 asm!("prfm pldl1strm, [{0}]", in(reg) ptrs[4]);
434 asm!("prfm pldl1strm, [{0}]", in(reg) ptrs[5]);
435 asm!("prfm pldl1strm, [{0}]", in(reg) ptrs[6]);
436 asm!("prfm pldl1strm, [{0}]", in(reg) ptrs[7]);
437 }
438 1 => {
439 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[0]);
441 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[1]);
442 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[2]);
443 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[3]);
444 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[4]);
445 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[5]);
446 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[6]);
447 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[7]);
448 }
449 2 => {
450 asm!("prfm pldl2keep, [{0}]", in(reg) ptrs[0]);
452 asm!("prfm pldl2keep, [{0}]", in(reg) ptrs[1]);
453 asm!("prfm pldl2keep, [{0}]", in(reg) ptrs[2]);
454 asm!("prfm pldl2keep, [{0}]", in(reg) ptrs[3]);
455 asm!("prfm pldl2keep, [{0}]", in(reg) ptrs[4]);
456 asm!("prfm pldl2keep, [{0}]", in(reg) ptrs[5]);
457 asm!("prfm pldl2keep, [{0}]", in(reg) ptrs[6]);
458 asm!("prfm pldl2keep, [{0}]", in(reg) ptrs[7]);
459 }
460 3 => {
461 asm!("prfm pldl3keep, [{0}]", in(reg) ptrs[0]);
463 asm!("prfm pldl3keep, [{0}]", in(reg) ptrs[1]);
464 asm!("prfm pldl3keep, [{0}]", in(reg) ptrs[2]);
465 asm!("prfm pldl3keep, [{0}]", in(reg) ptrs[3]);
466 asm!("prfm pldl3keep, [{0}]", in(reg) ptrs[4]);
467 asm!("prfm pldl3keep, [{0}]", in(reg) ptrs[5]);
468 asm!("prfm pldl3keep, [{0}]", in(reg) ptrs[6]);
469 asm!("prfm pldl3keep, [{0}]", in(reg) ptrs[7]);
470 }
471 _ => {
472 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[0]);
474 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[1]);
475 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[2]);
476 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[3]);
477 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[4]);
478 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[5]);
479 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[6]);
480 asm!("prfm pldl1keep, [{0}]", in(reg) ptrs[7]);
481 }
482 }
483 }
484 }
485
486 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
487 {
488 let _ = selected_addrs;
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_level_constants() {
        // Hints are small non-negative selectors consumed by the
        // `match L::HINT` dispatch in the prefetch functions.
        assert!(NTA::HINT >= 0);
        assert!(L1::HINT >= 0);
        assert!(L2::HINT >= 0);
        assert!(L3::HINT >= 0);
    }

    // Prefetch is a pure hint with no observable result, so these tests
    // only exercise each code path and check nothing crashes.
    //
    // `&data[0]` (not `&data`) is passed as the base so that offsets are
    // in units of the element type: passing `&data` would make `base` a
    // `&Vec<u32>` and scale offsets by the Vec header size instead.

    #[test]
    fn test_prefetch_single_address() {
        let data = vec![0u32; 100];

        prefetch_address::<_, NTA>(&data[0], 10);
        prefetch_address::<_, L1>(&data[0], 20);
        prefetch_address::<_, L2>(&data[0], 30);
        prefetch_address::<_, L3>(&data[0], 40);
    }

    #[test]
    fn test_prefetch_eight_addresses() {
        let data = vec![0u32; 100];
        let offsets = [10, 20, 30, 40, 50, 60, 70, 80];

        prefetch_eight_offsets::<_, NTA>(&data[0], &offsets);
        prefetch_eight_offsets::<_, L1>(&data[0], &offsets);
        prefetch_eight_offsets::<_, L2>(&data[0], &offsets);
        prefetch_eight_offsets::<_, L3>(&data[0], &offsets);
    }

    #[test]
    fn test_prefetch_eight_masked() {
        let data = vec![0u32; 100];
        let offsets = [10, 20, 30, 40, 50, 60, 70, 80];

        // Cover full, empty, alternating, and half masks.
        prefetch_eight_masked::<_, L1>(&data[0], offsets, 0xFF);
        prefetch_eight_masked::<_, L1>(&data[0], offsets, 0x00);
        prefetch_eight_masked::<_, L1>(&data[0], offsets, 0xAA);
        prefetch_eight_masked::<_, L1>(&data[0], offsets, 0x55);
        prefetch_eight_masked::<_, L1>(&data[0], offsets, 0x0F);
        prefetch_eight_masked::<_, L1>(&data[0], offsets, 0xF0);
    }

    #[test]
    fn test_different_data_types() {
        let u32_data = vec![0u32; 100];
        let u64_data = vec![0u64; 100];
        let f32_data = vec![0.0f32; 100];

        let offsets = [1, 2, 3, 4, 5, 6, 7, 8];

        prefetch_eight_offsets::<_, L1>(&u32_data[0], &offsets);
        prefetch_eight_offsets::<_, L1>(&u64_data[0], &offsets);
        prefetch_eight_offsets::<_, L1>(&f32_data[0], &offsets);

        prefetch_eight_masked::<_, L1>(&u32_data[0], offsets, 0xFF);
        prefetch_eight_masked::<_, L1>(&u64_data[0], offsets, 0xAA);
        prefetch_eight_masked::<_, L1>(&f32_data[0], offsets, 0x55);
    }
}