use crate::error::TruenoError;
/// Cache-line size in bytes assumed throughout this module (64 bytes on
/// mainstream x86_64 and aarch64 cores).
pub const CACHE_LINE_SIZE: usize = 64;
/// Number of `f32` elements that fit in one cache line (64 / 4 = 16).
pub const CACHE_LINE_SIZE_F32: usize = CACHE_LINE_SIZE / std::mem::size_of::<f32>();
/// Wrapper that forces its contents onto a 64-byte (cache-line) boundary,
/// e.g. to avoid false sharing between per-thread counters.
///
/// The derived `Clone`/`Default` impls are identical to the hand-written
/// `impl<T: Clone> Clone` / `impl<T: Default> Default` they replace.
#[repr(align(64))]
#[derive(Debug, Clone, Default)]
pub struct CacheAligned<T>(pub T);

impl<T> CacheAligned<T> {
    /// Wraps `value` in a cache-line-aligned cell.
    pub const fn new(value: T) -> Self {
        Self(value)
    }

    /// Returns a shared reference to the inner value.
    pub fn get(&self) -> &T {
        &self.0
    }

    /// Returns a mutable reference to the inner value.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.0
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        self.0
    }
}
/// Required address alignment for `O_DIRECT` file I/O: one 4 KiB page.
pub const DIRECT_IO_ALIGNMENT: usize = 4096;

/// Returns `true` when `ptr` sits on a [`DIRECT_IO_ALIGNMENT`] boundary.
///
/// The pointer is only inspected as an address; it is never dereferenced,
/// so null and dangling pointers are fine to pass.
#[must_use]
pub fn is_direct_io_aligned<T>(ptr: *const T) -> bool {
    (ptr as usize) % DIRECT_IO_ALIGNMENT == 0
}
/// Owned heap buffer whose base address is aligned to [`DIRECT_IO_ALIGNMENT`],
/// suitable for `O_DIRECT` file I/O.
#[cfg(not(target_arch = "wasm32"))]
pub struct AlignedBuffer {
    // Base pointer. For a zero-length buffer this is an aligned dangling
    // pointer and no allocation exists.
    ptr: *mut u8,
    // Length in bytes.
    len: usize,
    // Layout used for allocation; `layout.size() == 0` means "never allocated".
    layout: std::alloc::Layout,
}

#[cfg(not(target_arch = "wasm32"))]
impl AlignedBuffer {
    /// Allocates a zero-initialized buffer of `size` bytes aligned to
    /// [`DIRECT_IO_ALIGNMENT`].
    ///
    /// # Errors
    /// Returns `TruenoError::InvalidInput` when the layout is invalid
    /// (size overflows when rounded up to the alignment) or the allocator
    /// returns null.
    pub fn new(size: usize) -> Result<Self, TruenoError> {
        use std::alloc::{alloc_zeroed, Layout};
        let layout = Layout::from_size_align(size, DIRECT_IO_ALIGNMENT)
            .map_err(|e| TruenoError::InvalidInput(format!("invalid alignment: {e}")))?;
        // Calling `alloc_zeroed` with a zero-size layout is undefined behavior
        // per the `GlobalAlloc` contract, so a zero-length buffer gets an
        // aligned dangling pointer instead and never allocates.
        let ptr = if size == 0 {
            DIRECT_IO_ALIGNMENT as *mut u8
        } else {
            // SAFETY: `layout` has non-zero size here.
            let ptr = unsafe { alloc_zeroed(layout) };
            if ptr.is_null() {
                return Err(TruenoError::InvalidInput("allocation failed".into()));
            }
            ptr
        };
        Ok(Self { ptr, len: size, layout })
    }

    /// Views the buffer contents as an immutable byte slice.
    pub fn as_slice(&self) -> &[u8] {
        // SAFETY: `ptr` is non-null, aligned, and valid for `len` bytes
        // (zeroed by the allocator); for `len == 0` any non-null aligned
        // pointer is valid.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }

    /// Views the buffer contents as a mutable byte slice.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        // SAFETY: as `as_slice`, plus `&mut self` guarantees exclusivity.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Raw const pointer to the start of the buffer.
    pub fn as_ptr(&self) -> *const u8 {
        self.ptr
    }

    /// Raw mutable pointer to the start of the buffer.
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr
    }

    /// Length of the buffer in bytes.
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the buffer holds zero bytes.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}

#[cfg(not(target_arch = "wasm32"))]
impl Drop for AlignedBuffer {
    fn drop(&mut self) {
        // Zero-size buffers were never allocated; `dealloc` with a zero-size
        // layout would be undefined behavior, so skip it.
        if self.layout.size() != 0 {
            // SAFETY: `ptr` came from `alloc_zeroed(self.layout)` and is
            // freed exactly once, here.
            unsafe { std::alloc::dealloc(self.ptr, self.layout) };
        }
    }
}

// SAFETY: the buffer exclusively owns its allocation; the raw pointer is not
// shared with anything else, so moving it to another thread is sound.
#[cfg(not(target_arch = "wasm32"))]
unsafe impl Send for AlignedBuffer {}

// SAFETY: shared access only exposes `&[u8]`; all mutation goes through
// `&mut self`, so concurrent shared use cannot race.
#[cfg(not(target_arch = "wasm32"))]
unsafe impl Sync for AlignedBuffer {}
/// Access-pattern hint passed to [`madvise_region`], mirroring the Linux
/// `MADV_*` advice values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryAdvice {
    /// Pages will be accessed sequentially (kernel may read ahead).
    Sequential,
    /// Pages will be accessed in random order (read-ahead is wasted).
    Random,
    /// Pages will be needed soon; start paging them in.
    WillNeed,
    /// Pages are no longer needed; the kernel may reclaim them.
    DontNeed,
}
// Raw advice values from the Linux uapi (asm-generic/mman-common.h); these
// are identical across architectures.
#[cfg(target_os = "linux")]
const MADV_SEQUENTIAL: i32 = 2;
#[cfg(target_os = "linux")]
const MADV_RANDOM: i32 = 1;
#[cfg(target_os = "linux")]
const MADV_WILLNEED: i32 = 3;
#[cfg(target_os = "linux")]
const MADV_DONTNEED: i32 = 4;
/// Applies `advice` to the `len`-byte region starting at `addr` via a raw
/// Linux `madvise` syscall (no libc dependency).
///
/// # Errors
/// Returns the kernel's error (e.g. `EINVAL` for an unmapped/unaligned
/// range) as a `std::io::Error`.
///
/// # Safety
/// `addr`/`len` must describe memory this process may legally advise.
/// `MemoryAdvice::DontNeed` in particular can discard page contents.
#[cfg(target_os = "linux")]
pub unsafe fn madvise_region(
    addr: *mut u8,
    len: usize,
    advice: MemoryAdvice,
) -> std::io::Result<()> {
    unsafe {
        // Syscall numbers differ per architecture (asm/unistd.h):
        // madvise is 28 on x86_64 and 233 on aarch64.
        #[cfg(target_arch = "x86_64")]
        const SYS_MADVISE: i64 = 28;
        #[cfg(target_arch = "aarch64")]
        const SYS_MADVISE: i64 = 233;
        let advice_flag: i32 = match advice {
            MemoryAdvice::Sequential => MADV_SEQUENTIAL,
            MemoryAdvice::Random => MADV_RANDOM,
            MemoryAdvice::WillNeed => MADV_WILLNEED,
            MemoryAdvice::DontNeed => MADV_DONTNEED,
        };
        let ret: i64;
        #[cfg(target_arch = "x86_64")]
        {
            // x86_64 syscall ABI: number in rax, args in rdi/rsi/rdx, result
            // in rax; the `syscall` instruction clobbers rcx and r11.
            core::arch::asm!(
                "syscall",
                inout("rax") SYS_MADVISE => ret,
                in("rdi") addr as usize,
                in("rsi") len,
                in("rdx") advice_flag as i64,
                out("rcx") _,
                out("r11") _,
                options(nostack)
            );
        }
        #[cfg(target_arch = "aarch64")]
        {
            // aarch64 svc ABI: number in x8, args in x0..x2, result in x0.
            core::arch::asm!(
                "svc 0",
                inout("x8") SYS_MADVISE => _,
                inout("x0") addr as usize => ret,
                in("x1") len,
                in("x2") advice_flag as i64,
                options(nostack)
            );
        }
        // Raw syscalls signal failure by returning -errno (in [-4095, -1]).
        if ret < 0 {
            return Err(std::io::Error::from_raw_os_error(-ret as i32));
        }
        Ok(())
    }
}
/// No-op `madvise_region` for non-Linux targets: the advice is accepted and
/// silently ignored, so callers stay portable.
///
/// # Safety
/// The stub never touches `_addr`, so there is no memory-safety obligation
/// in practice; the `unsafe` marker only mirrors the Linux signature.
#[cfg(not(target_os = "linux"))]
pub unsafe fn madvise_region(
    _addr: *mut u8,
    _len: usize,
    _advice: MemoryAdvice,
) -> std::io::Result<()> {
    Ok(())
}
/// Advises the kernel that the region will be needed soon and then accessed
/// in random order — the typical pattern for memory-mapped model weights
/// before inference.
///
/// # Safety
/// Same contract as [`madvise_region`]: `addr`/`len` must describe memory
/// this process may legally advise.
#[cfg(target_os = "linux")]
pub unsafe fn prefetch_for_inference(addr: *mut u8, len: usize) -> std::io::Result<()> {
    // WILLNEED first (start paging data in), then RANDOM — presumably to
    // suppress sequential read-ahead during the access phase.
    for advice in [MemoryAdvice::WillNeed, MemoryAdvice::Random] {
        // SAFETY: the caller's contract is forwarded unchanged.
        unsafe { madvise_region(addr, len, advice)? };
    }
    Ok(())
}
/// No-op `prefetch_for_inference` for non-Linux targets.
///
/// # Safety
/// The stub never touches `_addr`; the `unsafe` marker only mirrors the
/// Linux signature.
#[cfg(not(target_os = "linux"))]
pub unsafe fn prefetch_for_inference(_addr: *mut u8, _len: usize) -> std::io::Result<()> {
    Ok(())
}
/// Temporal-locality hint for software prefetch, mirroring the x86
/// `_MM_HINT_*` levels (discriminants match those hint values).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
    /// Non-temporal: bypass as much of the cache hierarchy as possible.
    None = 0,
    /// Prefetch into the outermost cache level only.
    Low = 1,
    /// Prefetch into mid-level caches.
    Moderate = 2,
    /// Prefetch into all cache levels including L1.
    High = 3,
}

/// Issues a software prefetch for the cache line containing `ptr`.
///
/// # Safety
/// Prefetch is only a hint and does not fault, but `ptr` should point at or
/// near addressable memory for the hint to be useful.
#[inline]
#[cfg(target_arch = "x86_64")]
pub unsafe fn prefetch_ptr<T>(ptr: *const T, locality: PrefetchLocality) {
    unsafe {
        use core::arch::x86_64::*;
        // Fix: on current stable Rust `_mm_prefetch` takes its strategy as a
        // const generic (`_mm_prefetch::<STRATEGY>(p)`); the old two-argument
        // form no longer compiles.
        match locality {
            PrefetchLocality::None => _mm_prefetch::<_MM_HINT_NTA>(ptr as *const i8),
            PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(ptr as *const i8),
            PrefetchLocality::Moderate => _mm_prefetch::<_MM_HINT_T1>(ptr as *const i8),
            PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(ptr as *const i8),
        }
    }
}
/// Issues a software prefetch for the cache line containing `ptr` (aarch64).
///
/// Only `PRFM pldl1keep` is emitted, so the locality hint is currently
/// ignored on this architecture.
///
/// # Safety
/// `PRFM` is a hint instruction and never faults, even for invalid
/// addresses; callers need no additional invariants.
#[inline]
#[cfg(target_arch = "aarch64")]
pub unsafe fn prefetch_ptr<T>(ptr: *const T, _locality: PrefetchLocality) {
    // Fix: wrap the asm in an explicit `unsafe` block, consistent with every
    // other unsafe fn in this module (edition-2024 `unsafe_op_in_unsafe_fn`).
    unsafe {
        core::arch::asm!(
            "prfm pldl1keep, [{ptr}]",
            ptr = in(reg) ptr,
            options(nostack, preserves_flags)
        );
    }
}
/// No-op `prefetch_ptr` fallback for architectures without a prefetch
/// intrinsic wired up here; keeps callers portable.
///
/// # Safety
/// Does nothing, so it is unconditionally safe to call; the `unsafe` marker
/// only mirrors the arch-specific signatures.
#[inline]
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
pub unsafe fn prefetch_ptr<T>(_ptr: *const T, _locality: PrefetchLocality) {
}
/// Prefetches every cache line covered by `slice` with the given locality.
///
/// Safe wrapper around [`prefetch_ptr`]: each prefetched address lies inside
/// the slice's extent, and prefetch itself is only a hint and cannot fault.
#[inline]
pub fn prefetch_slice<T>(slice: &[T], locality: PrefetchLocality) {
    let base = slice.as_ptr().cast::<u8>();
    let total_bytes = std::mem::size_of_val(slice);
    let mut offset = 0;
    // Step one cache line at a time; an empty slice issues no prefetches.
    while offset < total_bytes {
        // SAFETY: offset < total_bytes, so the address is within the slice.
        unsafe { prefetch_ptr(base.add(offset), locality) };
        offset += CACHE_LINE_SIZE;
    }
}
// Unit tests for the alignment, buffer, advice, and prefetch helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // CacheAligned: alignment guarantee and accessor behavior.
    #[test]
    fn test_cache_aligned_alignment() {
        let aligned: CacheAligned<u64> = CacheAligned::new(42);
        assert_eq!(std::mem::align_of_val(&aligned), 64);
    }

    #[test]
    fn test_cache_aligned_value() {
        let aligned = CacheAligned::new(42u64);
        assert_eq!(*aligned.get(), 42);
    }

    #[test]
    fn test_cache_aligned_get_mut() {
        let mut aligned = CacheAligned::new(42u64);
        *aligned.get_mut() = 100;
        assert_eq!(*aligned.get(), 100);
    }

    #[test]
    fn test_cache_aligned_into_inner() {
        let aligned = CacheAligned::new(42u64);
        assert_eq!(aligned.into_inner(), 42);
    }

    #[test]
    fn test_cache_aligned_default() {
        let aligned: CacheAligned<u64> = CacheAligned::default();
        assert_eq!(*aligned.get(), 0);
    }

    #[test]
    fn test_cache_aligned_clone() {
        let aligned = CacheAligned::new(42u64);
        let cloned = aligned.clone();
        assert_eq!(*cloned.get(), 42);
    }

    // Constants: 64-byte line holds 16 f32s; direct I/O needs 4 KiB pages.
    #[test]
    fn test_cache_line_size_f32() {
        assert_eq!(CACHE_LINE_SIZE_F32, 16); }

    #[test]
    fn test_direct_io_alignment() {
        assert_eq!(DIRECT_IO_ALIGNMENT, 4096);
    }

    #[test]
    fn test_is_direct_io_aligned() {
        let aligned_addr: usize = 4096 * 10;
        let unaligned_addr: usize = 4096 * 10 + 1;
        assert!(is_direct_io_aligned(aligned_addr as *const u8));
        assert!(!is_direct_io_aligned(unaligned_addr as *const u8));
    }

    // AlignedBuffer: allocation, zero-initialization, and mutation.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_creation() {
        let buffer = AlignedBuffer::new(4096).unwrap();
        assert_eq!(buffer.len(), 4096);
        assert!(!buffer.is_empty());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_zeroed() {
        let buffer = AlignedBuffer::new(1024).unwrap();
        let slice = buffer.as_slice();
        assert!(slice.iter().all(|&b| b == 0));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_write() {
        let mut buffer = AlignedBuffer::new(1024).unwrap();
        buffer.as_mut_slice()[0] = 42;
        assert_eq!(buffer.as_slice()[0], 42);
    }

    #[test]
    fn test_memory_advice_eq() {
        assert_eq!(MemoryAdvice::Sequential, MemoryAdvice::Sequential);
        assert_ne!(MemoryAdvice::Sequential, MemoryAdvice::Random);
    }

    // Discriminants must match the x86 _MM_HINT_* values.
    #[test]
    fn test_prefetch_locality_values() {
        assert_eq!(PrefetchLocality::None as u8, 0);
        assert_eq!(PrefetchLocality::Low as u8, 1);
        assert_eq!(PrefetchLocality::Moderate as u8, 2);
        assert_eq!(PrefetchLocality::High as u8, 3);
    }

    // Prefetch has no observable result; these tests only check it runs.
    #[test]
    fn test_prefetch_slice_empty() {
        let empty: &[f32] = &[];
        prefetch_slice(empty, PrefetchLocality::High);
    }

    #[test]
    fn test_prefetch_slice_small() {
        let data = [1.0f32; 8];
        prefetch_slice(&data, PrefetchLocality::High);
    }

    // On Linux the result depends on the kernel, so only the non-Linux
    // stubs are asserted to succeed.
    #[test]
    fn test_madvise_region_stub() {
        unsafe {
            let mut data = [0u8; 4096];
            let _result = madvise_region(data.as_mut_ptr(), data.len(), MemoryAdvice::WillNeed);
            #[cfg(not(target_os = "linux"))]
            assert!(_result.is_ok());
        }
    }

    #[test]
    fn test_prefetch_for_inference_stub() {
        unsafe {
            let mut data = [0u8; 4096];
            let _result = prefetch_for_inference(data.as_mut_ptr(), data.len());
            #[cfg(not(target_os = "linux"))]
            assert!(_result.is_ok());
        }
    }
}