
oxibonsai_kernels/prefetch.rs

//! Software prefetch hints for GEMV/GEMM kernel operations.
//!
//! Provides platform-abstracted prefetch intrinsics that compile to
//! the appropriate hardware instruction on x86-64 (`_mm_prefetch`) and
//! AArch64 (`_prefetch`, which lowers to `PRFM`), and are no-ops on
//! platforms without support.
//!
//! These hints let the CPU begin loading cache lines before they are
//! needed, overlapping memory access with computation in memory-bound loops.
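//!
//! # Example
//!
//! A minimal sketch of the intended pattern (the `gemv_blocks` helper and
//! the crate path `oxibonsai_kernels::prefetch` are illustrative): pick a
//! [`PrefetchConfig`] for the kernel shape, then hint the block a few
//! iterations ahead of the one being processed.
//!
//! ```
//! use oxibonsai_kernels::prefetch::{prefetch_read, PrefetchConfig, PrefetchLocality};
//!
//! fn gemv_blocks(weights: &[f32], block_len: usize) {
//!     let cfg = PrefetchConfig::for_gemv();
//!     for (i, block) in weights.chunks_exact(block_len).enumerate() {
//!         // Hint the block `lookahead_blocks` iterations ahead of this one.
//!         let ahead = (i + cfg.lookahead_blocks) * block_len;
//!         if ahead < weights.len() {
//!             prefetch_read(weights[ahead..].as_ptr(), PrefetchLocality::High);
//!         }
//!         let _ = block; // ... compute on `block` here ...
//!     }
//! }
//! ```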

/// Number of blocks to prefetch ahead in GEMV/GEMM loops.
const DEFAULT_LOOKAHEAD_BLOCKS: usize = 4;

/// Configuration for prefetch behavior in kernel loops.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// How many blocks ahead to prefetch in the inner loop.
    /// Higher values hide more latency but consume more cache.
    pub lookahead_blocks: usize,
    /// Which prefetch strategy to use.
    pub strategy: PrefetchStrategy,
}

impl Default for PrefetchConfig {
    fn default() -> Self {
        Self {
            lookahead_blocks: DEFAULT_LOOKAHEAD_BLOCKS,
            strategy: PrefetchStrategy::Temporal,
        }
    }
}

impl PrefetchConfig {
    /// Create a config tuned for GEMV: a single input vector, where each
    /// prefetched weight block is consumed within a few iterations, so
    /// temporal hints into near caches pay off.
    pub fn for_gemv() -> Self {
        Self {
            lookahead_blocks: 4,
            strategy: PrefetchStrategy::Temporal,
        }
    }

    /// Create a config optimized for GEMM (batched input, streaming weights).
    ///
    /// In GEMM, each weight block is reused across the M dimension, so
    /// temporal locality is still useful. For very large M, however, the
    /// first touch of each weight block benefits from a non-temporal
    /// prefetch, which avoids polluting L1 with data that won't be reused
    /// for many iterations.
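    ///
    /// # Example
    ///
    /// A sketch of the batch-size cutover (the crate path
    /// `oxibonsai_kernels::prefetch` is assumed; the threshold of 32 rows
    /// is the one hard-coded below):
    ///
    /// ```
    /// use oxibonsai_kernels::prefetch::{PrefetchConfig, PrefetchStrategy};
    ///
    /// assert_eq!(PrefetchConfig::for_gemm(4).strategy, PrefetchStrategy::Temporal);
    /// assert_eq!(PrefetchConfig::for_gemm(64).strategy, PrefetchStrategy::NonTemporal);
    /// ```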
    pub fn for_gemm(batch_size: usize) -> Self {
        if batch_size > 32 {
            Self {
                lookahead_blocks: 8,
                strategy: PrefetchStrategy::NonTemporal,
            }
        } else {
            Self {
                lookahead_blocks: 4,
                strategy: PrefetchStrategy::Temporal,
            }
        }
    }

    /// No prefetching (baseline for benchmarking).
    pub fn none() -> Self {
        Self {
            lookahead_blocks: 0,
            strategy: PrefetchStrategy::None,
        }
    }
}

/// Prefetch strategy controlling cache line placement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchStrategy {
    /// No software prefetch hints issued.
    None,
    /// Prefetch for temporal locality: data is pulled into L1.
    /// Best when data will be reused soon (e.g., weight blocks reused across a batch).
    Temporal,
    /// Prefetch for non-temporal (streaming) access: data is fetched close to
    /// the core while minimizing cache pollution (exact placement varies by
    /// microarchitecture). Best when data is used once and then evicted
    /// (e.g., large streaming loads).
    NonTemporal,
}

/// Prefetch locality hint, controlling which cache level receives the data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
    /// Data will be reused imminently; prefetch to L1 (closest cache).
    High,
    /// Data might be reused; prefetch to L2.
    Medium,
    /// Data unlikely to be reused; prefetch with a non-temporal hint.
    Low,
}

/// Issue a software prefetch hint for a read from the given pointer.
///
/// This is a performance hint only; the CPU may ignore it. On platforms
/// without prefetch support, this is a no-op that compiles to nothing.
///
/// # Safety note
///
/// The pointer does not need to be valid (prefetching an invalid address
/// is architecturally a no-op on x86 and ARM), but callers should ensure
/// the address is within a reasonable range to avoid TLB pollution.
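///
/// # Example
///
/// A sketch of lookahead prefetching in a reduction loop (the crate path
/// `oxibonsai_kernels::prefetch` and the 64-element lookahead distance are
/// illustrative):
///
/// ```
/// use oxibonsai_kernels::prefetch::{prefetch_read, PrefetchLocality};
///
/// let a = vec![1.0f32; 4096];
/// let mut sum = 0.0f32;
/// for (i, &x) in a.iter().enumerate() {
///     if i % 16 == 0 {
///         // Hint data about four cache lines ahead of the cursor;
///         // `wrapping_add` keeps the address computation UB-free past the end.
///         prefetch_read(a.as_ptr().wrapping_add(i + 64), PrefetchLocality::High);
///     }
///     sum += x;
/// }
/// assert_eq!(sum, 4096.0);
/// ```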
#[inline(always)]
pub fn prefetch_read<T>(ptr: *const T, locality: PrefetchLocality) {
    // x86-64: _mm_prefetch
    #[cfg(target_arch = "x86_64")]
    {
        prefetch_read_x86(ptr.cast::<i8>(), locality);
    }

    // AArch64: _prefetch
    #[cfg(target_arch = "aarch64")]
    {
        prefetch_read_aarch64(ptr.cast::<i8>(), locality);
    }

    // All other platforms: no-op
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        let _ = ptr;
        let _ = locality;
    }
}

/// Issue a software prefetch hint for a write to the given pointer.
///
/// Asks the CPU to fetch the cache line in exclusive state, which can avoid
/// a read-for-ownership transaction on the first write. On x86-64 this
/// currently falls back to a plain read prefetch (see `prefetch_write_x86`).
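///
/// # Example
///
/// A minimal sketch (crate path assumed to be `oxibonsai_kernels::prefetch`):
/// warm an output buffer before an accumulation pass.
///
/// ```
/// use oxibonsai_kernels::prefetch::{prefetch_write, PrefetchLocality};
///
/// let mut out = vec![0.0f32; 1024];
/// // Hint that the first line of `out` is about to be written.
/// prefetch_write(out.as_mut_ptr(), PrefetchLocality::High);
/// for (i, y) in out.iter_mut().enumerate() {
///     *y += i as f32;
/// }
/// assert_eq!(out[3], 3.0);
/// ```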
#[inline(always)]
pub fn prefetch_write<T>(ptr: *mut T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        prefetch_write_x86(ptr.cast::<i8>(), locality);
    }

    #[cfg(target_arch = "aarch64")]
    {
        prefetch_write_aarch64(ptr.cast::<i8>(), locality);
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        let _ = ptr;
        let _ = locality;
    }
}

/// Issue read-prefetch hints covering `byte_count` bytes starting at `ptr`,
/// one hint per 64-byte cache line.
///
/// Useful for prefetching a contiguous array of blocks before processing.
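///
/// # Example
///
/// A minimal sketch (crate path assumed to be `oxibonsai_kernels::prefetch`):
/// hint a whole block before reducing over it.
///
/// ```
/// use oxibonsai_kernels::prefetch::{prefetch_range_read, PrefetchLocality};
///
/// let block = vec![1.0f32; 256];
/// let bytes = block.len() * std::mem::size_of::<f32>();
/// prefetch_range_read(block.as_ptr(), bytes, PrefetchLocality::Medium);
/// let sum: f32 = block.iter().sum();
/// assert_eq!(sum, 256.0);
/// ```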
#[inline]
pub fn prefetch_range_read<T>(ptr: *const T, byte_count: usize, locality: PrefetchLocality) {
    const CACHE_LINE: usize = 64;
    let mut offset = 0;
    while offset < byte_count {
        // `wrapping_add` keeps the address computation free of pointer-arithmetic
        // UB at the end of the range; the prefetch itself tolerates any address.
        let addr = (ptr as *const u8).wrapping_add(offset);
        prefetch_read(addr, locality);
        offset += CACHE_LINE;
    }
}

// ── x86-64 implementation ───────────────────────────────────────────────

#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn prefetch_read_x86(ptr: *const i8, locality: PrefetchLocality) {
    // SAFETY: `_mm_prefetch` is always safe to call; invalid addresses are
    // silently ignored by the hardware.
    unsafe {
        match locality {
            PrefetchLocality::High => {
                core::arch::x86_64::_mm_prefetch(ptr, core::arch::x86_64::_MM_HINT_T0);
            }
            PrefetchLocality::Medium => {
                core::arch::x86_64::_mm_prefetch(ptr, core::arch::x86_64::_MM_HINT_T1);
            }
            PrefetchLocality::Low => {
                core::arch::x86_64::_mm_prefetch(ptr, core::arch::x86_64::_MM_HINT_NTA);
            }
        }
    }
}

#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn prefetch_write_x86(ptr: *const i8, locality: PrefetchLocality) {
    // SSE has no dedicated write prefetch. `_mm_prefetch` with `_MM_HINT_ET0`
    // maps to PREFETCHW (fetch exclusive for writing), but not all x86-64 CPUs
    // support that instruction, so we conservatively issue a read prefetch,
    // which still warms the cache line.
    prefetch_read_x86(ptr, locality);
}

// ── AArch64 implementation ──────────────────────────────────────────────

#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn prefetch_read_aarch64(ptr: *const i8, locality: PrefetchLocality) {
    // SAFETY: `_prefetch` is safe to call; invalid addresses are silently
    // ignored on ARM. The intrinsic requires const `rw`/`locality` arguments,
    // so we match and call separately.
    unsafe {
        match locality {
            PrefetchLocality::High => {
                core::arch::aarch64::_prefetch(ptr, 0, 3); // keep in all cache levels
            }
            PrefetchLocality::Medium => {
                core::arch::aarch64::_prefetch(ptr, 0, 2); // keep in L2 and below
            }
            PrefetchLocality::Low => {
                core::arch::aarch64::_prefetch(ptr, 0, 0); // non-temporal / streaming
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn prefetch_write_aarch64(ptr: *const i8, locality: PrefetchLocality) {
    // SAFETY: same as the read variant; `rw = 1` selects a store prefetch.
    // Const arguments are required.
    unsafe {
        match locality {
            PrefetchLocality::High => {
                core::arch::aarch64::_prefetch(ptr, 1, 3);
            }
            PrefetchLocality::Medium => {
                core::arch::aarch64::_prefetch(ptr, 1, 2);
            }
            PrefetchLocality::Low => {
                core::arch::aarch64::_prefetch(ptr, 1, 0);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prefetch_config_defaults() {
        let config = PrefetchConfig::default();
        assert_eq!(config.lookahead_blocks, 4);
        assert_eq!(config.strategy, PrefetchStrategy::Temporal);
    }

    #[test]
    fn prefetch_config_for_gemv() {
        let config = PrefetchConfig::for_gemv();
        assert_eq!(config.strategy, PrefetchStrategy::Temporal);
        assert!(config.lookahead_blocks > 0);
    }

    #[test]
    fn prefetch_config_for_gemm_small_batch() {
        let config = PrefetchConfig::for_gemm(4);
        assert_eq!(config.strategy, PrefetchStrategy::Temporal);
    }

    #[test]
    fn prefetch_config_for_gemm_large_batch() {
        let config = PrefetchConfig::for_gemm(64);
        assert_eq!(config.strategy, PrefetchStrategy::NonTemporal);
        assert!(config.lookahead_blocks > 4);
    }

    #[test]
    fn prefetch_config_none() {
        let config = PrefetchConfig::none();
        assert_eq!(config.lookahead_blocks, 0);
        assert_eq!(config.strategy, PrefetchStrategy::None);
    }

    #[test]
    fn prefetch_read_smoke_test() {
        // Ensure calling prefetch_read doesn't crash.
        let data = [1.0f32, 2.0, 3.0, 4.0];
        prefetch_read(data.as_ptr(), PrefetchLocality::High);
        prefetch_read(data.as_ptr(), PrefetchLocality::Medium);
        prefetch_read(data.as_ptr(), PrefetchLocality::Low);
    }

    #[test]
    fn prefetch_write_smoke_test() {
        let mut data = [0.0f32; 16];
        prefetch_write(data.as_mut_ptr(), PrefetchLocality::High);
        prefetch_write(data.as_mut_ptr(), PrefetchLocality::Medium);
        prefetch_write(data.as_mut_ptr(), PrefetchLocality::Low);
        // The buffer should still be writable after prefetch.
        data[0] = 42.0;
        assert!((data[0] - 42.0).abs() < f32::EPSILON);
    }

    #[test]
    fn prefetch_range_read_smoke_test() {
        let data = vec![0.0f32; 1024];
        let byte_count = data.len() * std::mem::size_of::<f32>();
        prefetch_range_read(data.as_ptr(), byte_count, PrefetchLocality::High);
        prefetch_range_read(data.as_ptr(), byte_count, PrefetchLocality::Low);
    }

    #[test]
    fn prefetch_strategy_equality() {
        assert_eq!(PrefetchStrategy::None, PrefetchStrategy::None);
        assert_ne!(PrefetchStrategy::Temporal, PrefetchStrategy::NonTemporal);
    }

    #[test]
    fn prefetch_locality_equality() {
        assert_eq!(PrefetchLocality::High, PrefetchLocality::High);
        assert_ne!(PrefetchLocality::High, PrefetchLocality::Low);
    }
}