// Source: oxibonsai_kernels/prefetch.rs

/// Lookahead distance, in blocks, that `PrefetchConfig::default()` assigns to
/// `lookahead_blocks`.
const DEFAULT_LOOKAHEAD_BLOCKS: usize = 4;
12
13#[derive(Debug, Clone)]
15pub struct PrefetchConfig {
16 pub lookahead_blocks: usize,
19 pub strategy: PrefetchStrategy,
21}
22
23impl Default for PrefetchConfig {
24 fn default() -> Self {
25 Self {
26 lookahead_blocks: DEFAULT_LOOKAHEAD_BLOCKS,
27 strategy: PrefetchStrategy::Temporal,
28 }
29 }
30}
31
32impl PrefetchConfig {
33 pub fn for_gemv() -> Self {
35 Self {
36 lookahead_blocks: 4,
37 strategy: PrefetchStrategy::Temporal,
38 }
39 }
40
41 pub fn for_gemm(batch_size: usize) -> Self {
48 if batch_size > 32 {
49 Self {
50 lookahead_blocks: 8,
51 strategy: PrefetchStrategy::NonTemporal,
52 }
53 } else {
54 Self {
55 lookahead_blocks: 4,
56 strategy: PrefetchStrategy::Temporal,
57 }
58 }
59 }
60
61 pub fn none() -> Self {
63 Self {
64 lookahead_blocks: 0,
65 strategy: PrefetchStrategy::None,
66 }
67 }
68}
69
/// Strategy selector stored in `PrefetchConfig::strategy`.
///
/// `Hash` is derived alongside the comparison traits so the strategy can be
/// used as a map or set key.
// NOTE(review): how each strategy is translated into a `PrefetchLocality`
// is decided by the kernels that consume the config, not in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrefetchStrategy {
    /// No prefetching; used by `PrefetchConfig::none()` together with a
    /// lookahead of 0.
    None,
    /// Strategy used by the default config, `for_gemv()`, and `for_gemm()`
    /// for batch sizes up to 32.
    Temporal,
    /// Strategy used by `for_gemm()` for batch sizes above 32.
    NonTemporal,
}
82
/// Locality hint passed to the prefetch functions in this module.
///
/// The architecture helpers translate each level as follows:
///
/// | Level    | x86_64 hint      | aarch64 locality argument |
/// |----------|------------------|---------------------------|
/// | `High`   | `_MM_HINT_T0`    | 3                         |
/// | `Medium` | `_MM_HINT_T1`    | 2                         |
/// | `Low`    | `_MM_HINT_NTA`   | 0                         |
///
/// `Hash` is derived alongside the comparison traits so the level can be used
/// as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrefetchLocality {
    /// Strongest locality hint (`_MM_HINT_T0` / locality 3).
    High,
    /// Intermediate locality hint (`_MM_HINT_T1` / locality 2).
    Medium,
    /// Weakest locality hint (`_MM_HINT_NTA` / locality 0).
    Low,
}
93
94#[inline(always)]
105pub fn prefetch_read<T>(ptr: *const T, locality: PrefetchLocality) {
106 #[cfg(target_arch = "x86_64")]
108 {
109 prefetch_read_x86(ptr.cast::<i8>(), locality);
110 }
111
112 #[cfg(target_arch = "aarch64")]
114 {
115 prefetch_read_aarch64(ptr.cast::<i8>(), locality);
116 }
117
118 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
120 {
121 let _ = ptr;
122 let _ = locality;
123 }
124}
125
126#[inline(always)]
131pub fn prefetch_write<T>(ptr: *mut T, locality: PrefetchLocality) {
132 #[cfg(target_arch = "x86_64")]
133 {
134 prefetch_write_x86(ptr.cast::<i8>(), locality);
135 }
136
137 #[cfg(target_arch = "aarch64")]
138 {
139 prefetch_write_aarch64(ptr.cast::<i8>(), locality);
140 }
141
142 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
143 {
144 let _ = ptr;
145 let _ = locality;
146 }
147}
148
149#[inline]
153pub fn prefetch_range_read<T>(ptr: *const T, byte_count: usize, locality: PrefetchLocality) {
154 let cache_line = 64usize;
155 let mut offset = 0;
156 while offset < byte_count {
157 let addr = unsafe { (ptr as *const u8).add(offset) };
159 prefetch_read(addr, locality);
160 offset += cache_line;
161 }
162}
163
164#[cfg(target_arch = "x86_64")]
167#[inline(always)]
168fn prefetch_read_x86(ptr: *const i8, locality: PrefetchLocality) {
169 unsafe {
171 match locality {
172 PrefetchLocality::High => {
173 core::arch::x86_64::_mm_prefetch(ptr, core::arch::x86_64::_MM_HINT_T0);
174 }
175 PrefetchLocality::Medium => {
176 core::arch::x86_64::_mm_prefetch(ptr, core::arch::x86_64::_MM_HINT_T1);
177 }
178 PrefetchLocality::Low => {
179 core::arch::x86_64::_mm_prefetch(ptr, core::arch::x86_64::_MM_HINT_NTA);
180 }
181 }
182 }
183}
184
185#[cfg(target_arch = "x86_64")]
186#[inline(always)]
187fn prefetch_write_x86(ptr: *const i8, locality: PrefetchLocality) {
188 prefetch_read_x86(ptr, locality);
193}
194
/// aarch64 backend for read prefetching.
///
/// Calls `_prefetch` with the read/write argument set to 0 and a locality
/// argument derived from `locality`: `High` -> 3, `Medium` -> 2, `Low` -> 0.
// NOTE(review): `core::arch::aarch64::_prefetch` may require a nightly
// toolchain feature - confirm against the crate's supported compilers.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn prefetch_read_aarch64(ptr: *const i8, locality: PrefetchLocality) {
    use core::arch::aarch64::_prefetch;

    // SAFETY: the address is only passed to the prefetch intrinsic as a hint;
    // this code never reads or writes through `ptr`. This relies on the
    // intrinsic's documented behaviour of not faulting on arbitrary addresses.
    unsafe {
        match locality {
            PrefetchLocality::High => _prefetch(ptr, 0, 3),
            PrefetchLocality::Medium => _prefetch(ptr, 0, 2),
            PrefetchLocality::Low => _prefetch(ptr, 0, 0),
        }
    }
}
216
/// aarch64 backend for write prefetching.
///
/// Calls `_prefetch` with the read/write argument set to 1 and a locality
/// argument derived from `locality`: `High` -> 3, `Medium` -> 2, `Low` -> 0.
// NOTE(review): `core::arch::aarch64::_prefetch` may require a nightly
// toolchain feature - confirm against the crate's supported compilers.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn prefetch_write_aarch64(ptr: *const i8, locality: PrefetchLocality) {
    use core::arch::aarch64::_prefetch;

    // SAFETY: the address is only passed to the prefetch intrinsic as a hint;
    // this code never reads or writes through `ptr`. This relies on the
    // intrinsic's documented behaviour of not faulting on arbitrary addresses.
    unsafe {
        match locality {
            PrefetchLocality::High => _prefetch(ptr, 1, 3),
            PrefetchLocality::Medium => _prefetch(ptr, 1, 2),
            PrefetchLocality::Low => _prefetch(ptr, 1, 0),
        }
    }
}
235
/// Unit tests for the prefetch presets and smoke tests for the hint functions.
///
/// The hint functions have no observable result, so their tests only check
/// that the calls complete for every locality level.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prefetch_config_defaults() {
        let config = PrefetchConfig::default();
        assert_eq!(config.lookahead_blocks, 4);
        assert_eq!(config.strategy, PrefetchStrategy::Temporal);
    }

    #[test]
    fn prefetch_config_for_gemv() {
        let config = PrefetchConfig::for_gemv();
        assert_eq!(config.strategy, PrefetchStrategy::Temporal);
        assert!(config.lookahead_blocks > 0);
    }

    #[test]
    fn prefetch_config_for_gemm_small_batch() {
        let config = PrefetchConfig::for_gemm(4);
        assert_eq!(config.strategy, PrefetchStrategy::Temporal);
        assert_eq!(config.lookahead_blocks, 4);
    }

    #[test]
    fn prefetch_config_for_gemm_large_batch() {
        let config = PrefetchConfig::for_gemm(64);
        assert_eq!(config.strategy, PrefetchStrategy::NonTemporal);
        assert!(config.lookahead_blocks > 4);
    }

    /// The preset switches strictly above a batch size of 32.
    #[test]
    fn prefetch_config_for_gemm_batch_boundary() {
        let at_threshold = PrefetchConfig::for_gemm(32);
        assert_eq!(at_threshold.strategy, PrefetchStrategy::Temporal);
        assert_eq!(at_threshold.lookahead_blocks, 4);

        let above_threshold = PrefetchConfig::for_gemm(33);
        assert_eq!(above_threshold.strategy, PrefetchStrategy::NonTemporal);
        assert_eq!(above_threshold.lookahead_blocks, 8);

        let empty_batch = PrefetchConfig::for_gemm(0);
        assert_eq!(empty_batch.strategy, PrefetchStrategy::Temporal);
        assert_eq!(empty_batch.lookahead_blocks, 4);
    }

    #[test]
    fn prefetch_config_none() {
        let config = PrefetchConfig::none();
        assert_eq!(config.lookahead_blocks, 0);
        assert_eq!(config.strategy, PrefetchStrategy::None);
    }

    #[test]
    fn prefetch_read_smoke_test() {
        let data = [1.0f32, 2.0, 3.0, 4.0];
        for locality in [
            PrefetchLocality::High,
            PrefetchLocality::Medium,
            PrefetchLocality::Low,
        ] {
            prefetch_read(data.as_ptr(), locality);
        }
    }

    #[test]
    fn prefetch_write_smoke_test() {
        let mut data = [0.0f32; 16];
        for locality in [
            PrefetchLocality::High,
            PrefetchLocality::Medium,
            PrefetchLocality::Low,
        ] {
            prefetch_write(data.as_mut_ptr(), locality);
        }
        // The buffer must still be writable and hold what was stored.
        data[0] = 42.0;
        assert!((data[0] - 42.0).abs() < f32::EPSILON);
    }

    #[test]
    fn prefetch_range_read_smoke_test() {
        let data = vec![0.0f32; 1024];
        let byte_count = data.len() * std::mem::size_of::<f32>();
        prefetch_range_read(data.as_ptr(), byte_count, PrefetchLocality::High);
        prefetch_range_read(data.as_ptr(), byte_count, PrefetchLocality::Low);
    }

    /// A zero-length range issues no hints and must return immediately.
    #[test]
    fn prefetch_range_read_empty_range() {
        let data = [0u8; 1];
        prefetch_range_read(data.as_ptr(), 0, PrefetchLocality::High);
    }

    /// A length that is not a multiple of the 64-byte stride is accepted.
    #[test]
    fn prefetch_range_read_partial_line() {
        let data = vec![0u8; 100];
        prefetch_range_read(data.as_ptr(), data.len(), PrefetchLocality::Medium);
    }

    #[test]
    fn prefetch_strategy_equality() {
        assert_eq!(PrefetchStrategy::None, PrefetchStrategy::None);
        assert_ne!(PrefetchStrategy::Temporal, PrefetchStrategy::NonTemporal);
    }

    #[test]
    fn prefetch_locality_equality() {
        assert_eq!(PrefetchLocality::High, PrefetchLocality::High);
        assert_ne!(PrefetchLocality::High, PrefetchLocality::Low);
    }
}