#![cfg_attr(
    all(feature = "aarch64", target_arch = "aarch64"),
    feature(stdarch_aarch64_prefetch)
)]
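/// Prefetch `data[index]` into all levels of the cache hierarchy.
///
/// This issues a read prefetch with maximal temporal locality (`_MM_HINT_T0`
/// on x86/x86_64, locality 3 on aarch64), i.e. the cache line is expected to
/// be used again soon and should be kept in cache.
///
/// A prefetch is only a hint and never dereferences the pointer, so it is
/// fine to call this with an out-of-bounds `index`.
///
/// On targets without a supported prefetch intrinsic this is a no-op.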
#[inline(always)]
pub fn prefetch_index<T>(data: impl AsRef<[T]>, index: usize) {
    let ptr = data.as_ref().as_ptr().wrapping_add(index) as *const i8;
    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
    unsafe {
        std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(ptr);
    }
    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
    unsafe {
        std::arch::x86::_mm_prefetch::<{ std::arch::x86::_MM_HINT_T0 }>(ptr);
    }
    #[cfg(all(target_arch = "aarch64", feature = "aarch64"))]
    unsafe {
        std::arch::aarch64::_prefetch::<
            { std::arch::aarch64::_PREFETCH_READ },
            { std::arch::aarch64::_PREFETCH_LOCALITY3 },
        >(ptr);
    }
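    // Fallback for targets without a supported prefetch intrinsic: a prefetch
    // is purely a performance hint, so doing nothing is always correct.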
    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "sse"),
        all(target_arch = "x86", target_feature = "sse"),
        all(target_arch = "aarch64", feature = "aarch64")
    )))]
    {
        let _ = ptr;
    }
}
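/// Prefetch `data[index]` with a non-temporal hint.
///
/// The cache line is expected to be read once and not reused, so it is marked
/// for quick eviction (`_MM_HINT_NTA` on x86/x86_64, locality 0 on aarch64).
///
/// Like [`prefetch_index`], this is only a hint and is safe to call with an
/// out-of-bounds `index`.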
#[inline(always)]
pub fn prefetch_index_nta<T>(data: impl AsRef<[T]>, index: usize) {
    let ptr = data.as_ref().as_ptr().wrapping_add(index) as *const i8;
    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
    unsafe {
        std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_NTA }>(ptr);
    }
    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
    unsafe {
        std::arch::x86::_mm_prefetch::<{ std::arch::x86::_MM_HINT_NTA }>(ptr);
    }
    #[cfg(all(target_arch = "aarch64", feature = "aarch64"))]
    unsafe {
        std::arch::aarch64::_prefetch::<
            { std::arch::aarch64::_PREFETCH_READ },
            { std::arch::aarch64::_PREFETCH_LOCALITY0 },
        >(ptr);
    }
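    // As above: no prefetch intrinsic on this target, so do nothing.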
    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "sse"),
        all(target_arch = "x86", target_feature = "sse"),
        all(target_arch = "aarch64", feature = "aarch64")
    )))]
    {
        let _ = ptr;
    }
}
#[cfg(test)]
mod tests {
    use std::{array::from_fn, hint::black_box};

    use super::*;

    #[test]
    fn basic() {
        let data = vec![0u8; 1024];
        prefetch_index(&data, 512);
        prefetch_index_nta(&data, 512);
    }

    fn setup() -> Vec<u64> {
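        // 1 GiB of u64s, so the working set is far larger than any cache.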
        let len = 1024 * 1024 * 1024 / 8;
        let mut data = vec![0u64; len];
        let mut idx = 0;
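        // Build a pointer-chasing cycle: each visited slot stores the index
        // of the next slot to visit, with a step that grows by one each
        // iteration. The assert checks that no slot is visited twice.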
        for i in 0..len {
            let next = (idx + (i + 1)) % len;
            assert_eq!(data[idx], 0);
            data[idx] = next as u64;
            idx = next;
        }
        data
    }
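    /// Prefetching never dereferences the pointer, so even wildly
    /// out-of-bounds indices must not crash.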
    #[test]
    fn out_of_bounds_is_ok() {
        for j in 0..20 {
            let len = 1 << j;
            let data = vec![0u64; len];
            for i in 0..100usize {
                prefetch_index(&data, i.wrapping_neg());
                prefetch_index(&data, len + i);
            }
            prefetch_index(&data, usize::MAX);
            prefetch_index(&data, len + usize::MAX / 2);
        }
    }
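    /// Chase `B` independent pointer chains through the 1 GiB array from
    /// `setup()`. Without prefetching, every step is a dependent cache miss;
    /// with prefetching, the next element of each chain is requested early,
    /// so the misses of the `B` chains overlap. The test asserts at least a
    /// 20% speedup.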
    #[test]
    fn batched_pointer_chasing_prefetch() {
        const B: usize = 32;
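        // Number of independent pointer-chasing chains kept in flight at once.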
        let data = setup();

        let time = |prefetch: bool| {
            let mut indices: [usize; B] = from_fn(|i| i * data.len() / B + i);
            let start = std::time::Instant::now();
            let mut sum = 0;
            for _ in 0..data.len() / B {
                for idx in &mut indices {
                    *idx = data[*idx] as usize;
                    if prefetch {
                        prefetch_index(&data, *idx);
                    }
                    let mut x: usize = 1;
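                    // Some dependent arithmetic per element, so there is work
                    // to overlap with the outstanding prefetches.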
                    for _ in 0..16 {
                        x = x.wrapping_mul(*idx);
                        // Wrap instead of `sum += x` to avoid an overflow
                        // panic in debug builds.
                        sum = sum.wrapping_add(x);
                    }
                }
            }
            black_box(sum);
            let duration = start.elapsed();
            let ns_per_it = duration.as_nanos() as f64 / (data.len() as f64);
            eprintln!(
                "Prefetch={prefetch:<5}: {:?} {:5.2} ns/it",
                duration, ns_per_it
            );
            ns_per_it
        };
        let no_prefetch = time(false);
        let with_prefetch = time(true);
        assert!(
            with_prefetch < 0.8 * no_prefetch,
            "Time with prefetching {with_prefetch} is not sufficiently smaller than time without prefetching {no_prefetch}"
        );
    }
}