Skip to main content

prefetch_index/
lib.rs

1//! A small crate to prefetch an element of an array.
2//!
3//! Provides [`prefetch_index`] and [`prefetch_index_nta`] to prefetch the cache
4//! line containing `slice[index]`.
5
// Opt in to the nightly-only aarch64 prefetch intrinsic, but only when the
// crate's `aarch64` feature is requested AND we are actually compiling for
// aarch64 — stable builds on all other targets are unaffected.
#![cfg_attr(
    all(feature = "aarch64", target_arch = "aarch64"),
    feature(stdarch_aarch64_prefetch)
)]
10
/// Prefetches the cache line containing (the first byte of) `data[index]` into
/// _all_ levels of the cache.
///
/// On x86/x86_64, this uses _mm_prefetch which only requires the (commonly available) SSE feature.
///
/// On aarch64, this is gated behind the `aarch64` feature flag as `aarch64::_prefetch` is nightly.
///
/// On other architectures, this is a no-op.
///
/// Out-of-bounds indices are safe: the address is formed with `wrapping_add`
/// and prefetch instructions are hints that never fault.
///
/// ```
/// use prefetch_index::prefetch_index;
/// let data = vec![0u8; 1024];
/// prefetch_index(&data, 512);
/// ```
#[inline(always)]
pub fn prefetch_index<T>(data: impl AsRef<[T]>, index: usize) {
    // `wrapping_add` avoids pointer-arithmetic UB for out-of-bounds indices;
    // the resulting address is only ever used as a prefetch hint.
    let ptr = data.as_ref().as_ptr().wrapping_add(index) as *const i8;
    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
    // SAFETY: PREFETCH does not fault, even on invalid addresses.
    unsafe {
        // FIX: `_mm_prefetch` takes its hint as a const generic parameter
        // (`_mm_prefetch::<STRATEGY>(p)`); the old two-argument call form
        // no longer compiles.
        std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(ptr);
    }
    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
    // SAFETY: as above.
    unsafe {
        std::arch::x86::_mm_prefetch::<{ std::arch::x86::_MM_HINT_T0 }>(ptr);
    }
    #[cfg(all(target_arch = "aarch64", feature = "aarch64"))]
    // SAFETY: PRFM does not fault, even on invalid addresses.
    unsafe {
        // LOCALITY3 = keep in all cache levels, matching _MM_HINT_T0.
        std::arch::aarch64::_prefetch::<
            { std::arch::aarch64::_PREFETCH_READ },
            { std::arch::aarch64::_PREFETCH_LOCALITY3 },
        >(ptr);
    }
    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "sse"),
        all(target_arch = "x86", target_feature = "sse"),
        all(target_arch = "aarch64", feature = "aarch64")
    )))]
    {
        let _ = ptr; // Silence unused variable warning.
    }
}
52
/// Prefetches the cache line containing (the first byte of) `data[index]` for non-temporal access.
///
/// On x86/x86_64, this uses _mm_prefetch which only requires the (commonly available) SSE feature.
///
/// On aarch64, this is gated behind the `aarch64` feature flag as `aarch64::_prefetch` is nightly.
///
/// On other architectures, this is a no-op.
///
/// Out-of-bounds indices are safe: the address is formed with `wrapping_add`
/// and prefetch instructions are hints that never fault.
#[inline(always)]
pub fn prefetch_index_nta<T>(data: impl AsRef<[T]>, index: usize) {
    // `wrapping_add` avoids pointer-arithmetic UB for out-of-bounds indices;
    // the resulting address is only ever used as a prefetch hint.
    let ptr = data.as_ref().as_ptr().wrapping_add(index) as *const i8;
    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
    // SAFETY: PREFETCHNTA does not fault, even on invalid addresses.
    unsafe {
        // FIX: `_mm_prefetch` takes its hint as a const generic parameter
        // (`_mm_prefetch::<STRATEGY>(p)`); the old two-argument call form
        // no longer compiles.
        std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_NTA }>(ptr);
    }
    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
    // SAFETY: as above.
    unsafe {
        std::arch::x86::_mm_prefetch::<{ std::arch::x86::_MM_HINT_NTA }>(ptr);
    }
    #[cfg(all(target_arch = "aarch64", feature = "aarch64"))]
    // SAFETY: PRFM does not fault, even on invalid addresses.
    unsafe {
        // LOCALITY0 = streaming / non-temporal, matching _MM_HINT_NTA.
        std::arch::aarch64::_prefetch::<
            { std::arch::aarch64::_PREFETCH_READ },
            { std::arch::aarch64::_PREFETCH_LOCALITY0 },
        >(ptr);
    }
    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "sse"),
        all(target_arch = "x86", target_feature = "sse"),
        all(target_arch = "aarch64", feature = "aarch64")
    )))]
    {
        let _ = ptr; // Silence unused variable warning.
    }
}
87
#[cfg(test)]
mod tests {
    use std::{array::from_fn, hint::black_box};

    use super::*;

    /// Smoke test: prefetching an in-bounds index compiles and runs
    /// without crashing on the current target.
    #[test]
    fn basic() {
        let data = vec![0u8; 1024];
        prefetch_index(&data, 512);
        prefetch_index_nta(&data, 512);
    }

    /// Returns a large vector with each index pointing to the next cache line.
    ///
    /// Builds a 1 GiB pointer-chasing permutation: following
    /// `idx = data[idx]` visits every slot before cycling. The vector is
    /// deliberately far larger than any cache so every hop is a memory miss.
    fn setup() -> Vec<u64> {
        // 1 GiB of u64s.
        let len = 1024 * 1024 * 1024 / 8;
        let mut data = vec![0u64; len];
        // triangular stride hits all positions (len is a power of two):
        // 0->1->3->6->...
        let mut idx = 0;
        for i in 0..len {
            let next = (idx + (i + 1)) % len;
            // Each slot must be written exactly once — i.e. the triangular
            // walk really is a single cycle over all indices.
            assert_eq!(data[idx], 0);
            data[idx] = next as u64;
            idx = next;
        }
        data
    }

    /// Any index must be safe to prefetch: the address is formed with
    /// `wrapping_add` and the prefetch is only a hint, so wildly
    /// out-of-bounds and wrapping indices must not fault.
    #[test]
    fn out_of_bounds_is_ok() {
        for j in 0..20 {
            let len = 1 << j;
            let data = vec![0u64; len];
            for i in 0..100usize {
                // Just below index 0 (wraps around the address space)...
                prefetch_index(&data, i.wrapping_neg());
                // ...and just past the end.
                prefetch_index(&data, len + i);
            }
            prefetch_index(&data, usize::MAX);
            prefetch_index(&data, len + usize::MAX / 2);
        }
    }

    /// Runs B independent pointer chases in lockstep, with and without
    /// prefetching the next hop of each chase, and asserts prefetching
    /// gives a >20% speedup.
    ///
    /// NOTE(review): timing-based — may be flaky on loaded or virtualized
    /// machines, and allocates 1 GiB via `setup()`.
    #[test]
    fn batched_pointer_chasing_prefetch() {
        // batch size
        const B: usize = 32;

        let data = setup();

        let time = |prefetch: bool| {
            // Start the B chases spread evenly through the permutation
            // (offset by `i` so they don't collide on the same line).
            let mut indices: [usize; B] = from_fn(|i| i * data.len() / B + i);
            let start = std::time::Instant::now();
            let mut sum = 0;
            for _ in 0..data.len() / B {
                for idx in &mut indices {
                    *idx = data[*idx] as usize;
                    if prefetch {
                        prefetch_index(&data, *idx);
                    }
                    // Simulate some work to fill the CPU reorder buffer.
                    let mut x: usize = 1;
                    for _ in 0..16 {
                        x = x.wrapping_mul(*idx);
                        sum += x;
                    }
                }
            }
            // black_box keeps the optimizer from deleting the whole loop.
            black_box(sum);
            let duration = start.elapsed();
            let ns_per_it = duration.as_nanos() as f64 / (data.len() as f64);
            eprintln!(
                "Prefetch={prefetch:<5}:    {:?}  {:5.2} ns/it",
                duration, ns_per_it
            );
            ns_per_it
        };
        let no_prefetch = time(false);
        let with_prefetch = time(true);
        assert!(with_prefetch < 0.8 * no_prefetch, "Time with prefetching {with_prefetch} is not sufficiently smaller than time without prefetching {no_prefetch}");
    }
}
169}