//! Memory alignment and prefetch utilities: cache-line-aligned wrappers,
//! direct-I/O buffers, `madvise` hints, and software prefetch helpers.

use crate::error::TruenoError;

/// Size of a CPU cache line in bytes (64 on x86_64 and most AArch64 cores).
pub const CACHE_LINE_SIZE: usize = 64;

/// Number of `f32` values that fit in a single cache line.
pub const CACHE_LINE_SIZE_F32: usize = CACHE_LINE_SIZE / std::mem::size_of::<f32>();

/// Wraps a value and aligns it to a 64-byte cache-line boundary.
///
/// Keeping hot values on separate cache lines prevents false sharing when
/// adjacent data is written concurrently by different threads.
#[repr(align(64))]
#[derive(Debug)]
pub struct CacheAligned<T>(pub T);

impl<T> CacheAligned<T> {
    /// Creates a new cache-aligned value.
    pub const fn new(value: T) -> Self {
        Self(value)
    }

    /// Returns a shared reference to the inner value.
    pub fn get(&self) -> &T {
        &self.0
    }

    /// Returns a mutable reference to the inner value.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.0
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T: Default> Default for CacheAligned<T> {
    fn default() -> Self {
        Self(T::default())
    }
}

impl<T: Clone> Clone for CacheAligned<T> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

/// Alignment required for `O_DIRECT` file I/O (one 4 KiB page).
pub const DIRECT_IO_ALIGNMENT: usize = 4096;

/// Returns `true` if `ptr` satisfies the direct-I/O alignment requirement.
#[must_use]
pub fn is_direct_io_aligned<T>(ptr: *const T) -> bool {
    (ptr as usize).is_multiple_of(DIRECT_IO_ALIGNMENT)
}

/// A zero-initialized heap buffer whose start address is aligned to
/// [`DIRECT_IO_ALIGNMENT`], as required for `O_DIRECT` reads and writes.
#[cfg(not(target_arch = "wasm32"))]
pub struct AlignedBuffer {
    ptr: *mut u8,
    len: usize,
    layout: std::alloc::Layout,
}

#[cfg(not(target_arch = "wasm32"))]
impl AlignedBuffer {
    /// Allocates a zero-initialized buffer of `size` bytes aligned to
    /// [`DIRECT_IO_ALIGNMENT`].
    ///
    /// # Errors
    ///
    /// Returns an error if `size` is zero, the layout is invalid, or the
    /// allocation fails.
    pub fn new(size: usize) -> Result<Self, TruenoError> {
        use std::alloc::{alloc_zeroed, Layout};

        // Allocating a zero-sized layout is undefined behavior, so reject
        // empty buffers up front.
        if size == 0 {
            return Err(TruenoError::InvalidInput("size must be non-zero".into()));
        }

        let layout = Layout::from_size_align(size, DIRECT_IO_ALIGNMENT)
            .map_err(|e| TruenoError::InvalidInput(format!("invalid alignment: {e}")))?;

        // SAFETY: `layout` has a non-zero size, checked above.
        let ptr = unsafe { alloc_zeroed(layout) };
        if ptr.is_null() {
            return Err(TruenoError::InvalidInput("allocation failed".into()));
        }

        Ok(Self { ptr, len: size, layout })
    }

    /// Returns the buffer contents as a byte slice.
    pub fn as_slice(&self) -> &[u8] {
        // SAFETY: `ptr` is valid for `len` bytes for the lifetime of `self`.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }

    /// Returns the buffer contents as a mutable byte slice.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        // SAFETY: `ptr` is valid for `len` bytes and `&mut self` guarantees
        // exclusive access.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Returns a raw pointer to the start of the buffer.
    pub fn as_ptr(&self) -> *const u8 {
        self.ptr
    }

    /// Returns a mutable raw pointer to the start of the buffer.
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr
    }

    /// Returns the buffer length in bytes.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the buffer has zero length.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}

#[cfg(not(target_arch = "wasm32"))]
impl Drop for AlignedBuffer {
    fn drop(&mut self) {
        // SAFETY: `ptr` was allocated in `new` with exactly this `layout`.
        unsafe {
            std::alloc::dealloc(self.ptr, self.layout);
        }
    }
}

// SAFETY: `AlignedBuffer` exclusively owns its allocation; the raw pointer
// is never shared outside the usual borrowing rules.
#[cfg(not(target_arch = "wasm32"))]
unsafe impl Send for AlignedBuffer {}

#[cfg(not(target_arch = "wasm32"))]
unsafe impl Sync for AlignedBuffer {}

/// Access-pattern hints passed to `madvise`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryAdvice {
    /// Pages will be read sequentially (`MADV_SEQUENTIAL`).
    Sequential,
    /// Pages will be accessed in random order (`MADV_RANDOM`).
    Random,
    /// Pages will be needed soon; page them in eagerly (`MADV_WILLNEED`).
    WillNeed,
    /// Pages are no longer needed and may be reclaimed (`MADV_DONTNEED`).
    DontNeed,
}

// `madvise` advice values from the Linux UAPI headers.
#[cfg(target_os = "linux")]
const MADV_SEQUENTIAL: i32 = 2;
#[cfg(target_os = "linux")]
const MADV_RANDOM: i32 = 1;
#[cfg(target_os = "linux")]
const MADV_WILLNEED: i32 = 3;
#[cfg(target_os = "linux")]
const MADV_DONTNEED: i32 = 4;

/// Advises the kernel about the expected access pattern for a memory
/// region, using a raw `madvise` syscall to avoid a libc dependency.
///
/// # Safety
///
/// `addr` must be page-aligned, and `addr..addr + len` must lie within a
/// mapping owned by the calling process.
#[cfg(target_os = "linux")]
pub unsafe fn madvise_region(
    addr: *mut u8,
    len: usize,
    advice: MemoryAdvice,
) -> std::io::Result<()> {
    unsafe {
        // The syscall number for `madvise` differs per architecture.
        #[cfg(target_arch = "x86_64")]
        const SYS_MADVISE: i64 = 28;
        #[cfg(target_arch = "aarch64")]
        const SYS_MADVISE: i64 = 233;

        let advice_flag: i32 = match advice {
            MemoryAdvice::Sequential => MADV_SEQUENTIAL,
            MemoryAdvice::Random => MADV_RANDOM,
            MemoryAdvice::WillNeed => MADV_WILLNEED,
            MemoryAdvice::DontNeed => MADV_DONTNEED,
        };

        let ret: i64;
        #[cfg(target_arch = "x86_64")]
        {
            core::arch::asm!(
                "syscall",
                inout("rax") SYS_MADVISE => ret,
                in("rdi") addr as usize,
                in("rsi") len,
                in("rdx") advice_flag as i64,
                // The kernel clobbers rcx and r11 on `syscall`.
                out("rcx") _,
                out("r11") _,
                options(nostack)
            );
        }
        #[cfg(target_arch = "aarch64")]
        {
            core::arch::asm!(
                "svc 0",
                inout("x8") SYS_MADVISE => _,
                inout("x0") addr as usize => ret,
                in("x1") len,
                in("x2") advice_flag as i64,
                options(nostack)
            );
        }

        // Linux syscalls report failure by returning a negated errno.
        if ret < 0 {
            return Err(std::io::Error::from_raw_os_error(-ret as i32));
        }

        Ok(())
    }
}

/// No-op stub for platforms without `madvise`.
#[cfg(not(target_os = "linux"))]
pub unsafe fn madvise_region(
    _addr: *mut u8,
    _len: usize,
    _advice: MemoryAdvice,
) -> std::io::Result<()> {
    Ok(())
}

/// Prepares a weight region for inference: first asks the kernel to page
/// the region in (`WillNeed`), then marks it random-access (`Random`) so
/// readahead is not wasted on scattered lookups.
///
/// # Safety
///
/// Same requirements as [`madvise_region`].
#[cfg(target_os = "linux")]
pub unsafe fn prefetch_for_inference(addr: *mut u8, len: usize) -> std::io::Result<()> {
    unsafe {
        madvise_region(addr, len, MemoryAdvice::WillNeed)?;
        madvise_region(addr, len, MemoryAdvice::Random)?;
        Ok(())
    }
}

/// No-op stub for platforms without `madvise`.
#[cfg(not(target_os = "linux"))]
pub unsafe fn prefetch_for_inference(_addr: *mut u8, _len: usize) -> std::io::Result<()> {
    Ok(())
}

/// Temporal-locality hint for software prefetch, mirroring the x86
/// `_MM_HINT_*` levels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
    /// Non-temporal: the data will be used once and should not pollute
    /// the cache.
    None = 0,
    /// Low locality: prefetch into the outermost cache level.
    Low = 1,
    /// Moderate locality: prefetch into mid-level cache.
    Moderate = 2,
    /// High locality: prefetch into all cache levels.
    High = 3,
}

/// Issues a software prefetch for the cache line containing `ptr`.
///
/// # Safety
///
/// Prefetch is a hint and never faults, but `ptr` should point at or near
/// memory the caller may legally access.
#[inline]
#[cfg(target_arch = "x86_64")]
pub unsafe fn prefetch_ptr<T>(ptr: *const T, locality: PrefetchLocality) {
    unsafe {
        use core::arch::x86_64::*;
        // `_mm_prefetch` takes its hint as a const generic parameter.
        match locality {
            PrefetchLocality::None => _mm_prefetch::<_MM_HINT_NTA>(ptr as *const i8),
            PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(ptr as *const i8),
            PrefetchLocality::Moderate => _mm_prefetch::<_MM_HINT_T1>(ptr as *const i8),
            PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(ptr as *const i8),
        }
    }
}

/// Issues a software prefetch for the cache line containing `ptr`.
///
/// The locality hint is ignored on AArch64; every level maps to
/// `prfm pldl1keep` (prefetch for load, keep in L1).
#[inline]
#[cfg(target_arch = "aarch64")]
pub unsafe fn prefetch_ptr<T>(ptr: *const T, _locality: PrefetchLocality) {
    unsafe {
        core::arch::asm!(
            "prfm pldl1keep, [{ptr}]",
            ptr = in(reg) ptr,
            options(nostack, preserves_flags)
        );
    }
}

/// No-op fallback for architectures without an exposed prefetch hint.
#[inline]
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
pub unsafe fn prefetch_ptr<T>(_ptr: *const T, _locality: PrefetchLocality) {}

/// Prefetches every cache line spanned by `slice`.
///
/// Safe to call on any slice, including an empty one, since hints are only
/// issued for addresses inside the slice.
#[inline]
pub fn prefetch_slice<T>(slice: &[T], locality: PrefetchLocality) {
    let ptr = slice.as_ptr() as *const u8;
    let len = std::mem::size_of_val(slice);

    // One prefetch per cache line covering the slice.
    for offset in (0..len).step_by(CACHE_LINE_SIZE) {
        // SAFETY: `offset < len`, so the address stays inside the slice.
        unsafe {
            prefetch_ptr(ptr.add(offset), locality);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_aligned_alignment() {
        let aligned: CacheAligned<u64> = CacheAligned::new(42);
        assert_eq!(std::mem::align_of_val(&aligned), 64);
    }

    #[test]
    fn test_cache_aligned_value() {
        let aligned = CacheAligned::new(42u64);
        assert_eq!(*aligned.get(), 42);
    }

    #[test]
    fn test_cache_aligned_get_mut() {
        let mut aligned = CacheAligned::new(42u64);
        *aligned.get_mut() = 100;
        assert_eq!(*aligned.get(), 100);
    }

    #[test]
    fn test_cache_aligned_into_inner() {
        let aligned = CacheAligned::new(42u64);
        assert_eq!(aligned.into_inner(), 42);
    }

    #[test]
    fn test_cache_aligned_default() {
        let aligned: CacheAligned<u64> = CacheAligned::default();
        assert_eq!(*aligned.get(), 0);
    }

    #[test]
    fn test_cache_aligned_clone() {
        let aligned = CacheAligned::new(42u64);
        let cloned = aligned.clone();
        assert_eq!(*cloned.get(), 42);
    }

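    // Hedged sketch of the intended use: per-thread counters kept on
    // separate cache lines so concurrent writers do not false-share. This
    // only checks the layout guarantee, not the performance effect.
    #[test]
    fn test_cache_aligned_layout_prevents_false_sharing() {
        let counters: Vec<CacheAligned<u64>> = vec![CacheAligned::new(0); 4];
        let first = &counters[0] as *const CacheAligned<u64> as usize;
        let second = &counters[1] as *const CacheAligned<u64> as usize;
        // Each element starts on its own cache line.
        assert_eq!(first % CACHE_LINE_SIZE, 0);
        assert!(second - first >= CACHE_LINE_SIZE);
    }
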
    #[test]
    fn test_cache_line_size_f32() {
        assert_eq!(CACHE_LINE_SIZE_F32, 16);
    }

    #[test]
    fn test_direct_io_alignment() {
        assert_eq!(DIRECT_IO_ALIGNMENT, 4096);
    }

    #[test]
    fn test_is_direct_io_aligned() {
        let aligned_addr: usize = 4096 * 10;
        let unaligned_addr: usize = 4096 * 10 + 1;

        assert!(is_direct_io_aligned(aligned_addr as *const u8));
        assert!(!is_direct_io_aligned(unaligned_addr as *const u8));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_creation() {
        let buffer = AlignedBuffer::new(4096).unwrap();
        assert_eq!(buffer.len(), 4096);
        assert!(!buffer.is_empty());
    }

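    // `AlignedBuffer::new` rejects zero-sized requests up front, since
    // allocating a zero-sized layout would be undefined behavior.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_zero_size_rejected() {
        assert!(AlignedBuffer::new(0).is_err());
    }
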
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_zeroed() {
        let buffer = AlignedBuffer::new(1024).unwrap();
        let slice = buffer.as_slice();
        assert!(slice.iter().all(|&b| b == 0));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_write() {
        let mut buffer = AlignedBuffer::new(1024).unwrap();
        buffer.as_mut_slice()[0] = 42;
        assert_eq!(buffer.as_slice()[0], 42);
    }

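    // Hedged sketch of the direct-I/O use case: O_DIRECT transfers require
    // the buffer address (and transfer size) to be 4096-byte aligned,
    // which is exactly what `AlignedBuffer` guarantees.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_satisfies_direct_io() {
        let buffer = AlignedBuffer::new(2 * DIRECT_IO_ALIGNMENT).unwrap();
        assert!(is_direct_io_aligned(buffer.as_ptr()));
        assert!(buffer.len().is_multiple_of(DIRECT_IO_ALIGNMENT));
    }
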
    #[test]
    fn test_memory_advice_eq() {
        assert_eq!(MemoryAdvice::Sequential, MemoryAdvice::Sequential);
        assert_ne!(MemoryAdvice::Sequential, MemoryAdvice::Random);
    }

    #[test]
    fn test_prefetch_locality_values() {
        assert_eq!(PrefetchLocality::None as u8, 0);
        assert_eq!(PrefetchLocality::Low as u8, 1);
        assert_eq!(PrefetchLocality::Moderate as u8, 2);
        assert_eq!(PrefetchLocality::High as u8, 3);
    }

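    // Prefetch is a pure hint: it never faults, so exercising every
    // locality level on a live allocation is safe.
    #[test]
    fn test_prefetch_ptr_all_localities() {
        let data = [0u8; 64];
        for locality in [
            PrefetchLocality::None,
            PrefetchLocality::Low,
            PrefetchLocality::Moderate,
            PrefetchLocality::High,
        ] {
            unsafe { prefetch_ptr(data.as_ptr(), locality) };
        }
    }
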
    #[test]
    fn test_prefetch_slice_empty() {
        let empty: &[f32] = &[];
        prefetch_slice(empty, PrefetchLocality::High);
    }

    #[test]
    fn test_prefetch_slice_small() {
        let data = [1.0f32; 8];
        prefetch_slice(&data, PrefetchLocality::High);
    }

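    // Hedged sketch of the intended pattern: prefetch the next tile while
    // reducing the current one. The tile size (4 cache lines) is
    // illustrative, not tuned.
    #[test]
    fn test_prefetch_next_tile_while_reducing() {
        let data = vec![1.0f32; 1024];
        let tile = CACHE_LINE_SIZE_F32 * 4;
        let tiles: Vec<&[f32]> = data.chunks(tile).collect();
        let mut sum = 0.0f32;
        for (i, current) in tiles.iter().enumerate() {
            if let Some(next) = tiles.get(i + 1) {
                prefetch_slice(next, PrefetchLocality::High);
            }
            sum += current.iter().sum::<f32>();
        }
        assert_eq!(sum, 1024.0);
    }
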
    #[test]
    fn test_madvise_region_stub() {
        unsafe {
            // A stack array is not page-aligned, so on Linux the real
            // syscall may return EINVAL; success is only asserted for the
            // non-Linux stub.
            let mut data = [0u8; 4096];
            let _result = madvise_region(data.as_mut_ptr(), data.len(), MemoryAdvice::WillNeed);
            #[cfg(not(target_os = "linux"))]
            assert!(_result.is_ok());
        }
    }

    #[test]
    fn test_prefetch_for_inference_stub() {
        unsafe {
            // Same caveat as above: success is only asserted for the stub.
            let mut data = [0u8; 4096];
            let _result = prefetch_for_inference(data.as_mut_ptr(), data.len());
            #[cfg(not(target_os = "linux"))]
            assert!(_result.is_ok());
        }
    }
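
    // Hedged sketch: `madvise` needs a page-aligned address, which
    // `AlignedBuffer` provides on the common 4 KiB-page kernels. Gated to
    // x86_64 Linux, where the base page size is always 4 KiB.
    #[cfg(all(target_os = "linux", target_arch = "x86_64"))]
    #[test]
    fn test_madvise_on_page_aligned_buffer() {
        let mut buffer = AlignedBuffer::new(DIRECT_IO_ALIGNMENT).unwrap();
        let result = unsafe {
            madvise_region(buffer.as_mut_ptr(), buffer.len(), MemoryAdvice::Sequential)
        };
        assert!(result.is_ok());
    }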
}