// trueno/brick/memory.rs
1//! Memory Management Primitives
2//!
3//! Cache line alignment, direct I/O buffers, memory advice, and prefetch utilities.
4
5use crate::error::TruenoError;
6
7// ----------------------------------------------------------------------------
8// LCP-06: Cache Line Padding
9// ----------------------------------------------------------------------------
10
/// Size in bytes of one CPU cache line (64 on virtually all modern
/// x86_64 and AArch64 parts).
pub const CACHE_LINE_SIZE: usize = 64;

/// How many `f32` elements fit in a single cache line (64 / 4 = 16).
pub const CACHE_LINE_SIZE_F32: usize = CACHE_LINE_SIZE / std::mem::size_of::<f32>();
16
/// Wrapper that forces 64-byte (cache-line) alignment on its contents,
/// preventing false sharing between adjacent values.
///
/// # Example
/// ```rust
/// use trueno::brick::CacheAligned;
/// use std::sync::atomic::AtomicU64;
///
/// let aligned: CacheAligned<AtomicU64> = CacheAligned::new(AtomicU64::new(0));
/// assert_eq!(std::mem::align_of_val(&aligned), 64);
/// ```
#[repr(align(64))]
#[derive(Debug)]
pub struct CacheAligned<T>(pub T);

impl<T> CacheAligned<T> {
    /// Wrap `value` in a cache-line-aligned container.
    pub const fn new(value: T) -> Self {
        Self(value)
    }

    /// Borrow the wrapped value.
    pub fn get(&self) -> &T {
        &self.0
    }

    /// Mutably borrow the wrapped value.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.0
    }

    /// Unwrap, moving the inner value out.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T: Clone> Clone for CacheAligned<T> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<T: Default> Default for CacheAligned<T> {
    fn default() -> Self {
        Self(T::default())
    }
}
64
65// ----------------------------------------------------------------------------
66// LCP-02: Direct I/O Alignment
67// ----------------------------------------------------------------------------
68
/// Required memory alignment for direct (`O_DIRECT`) I/O: one 4 KiB page.
pub const DIRECT_IO_ALIGNMENT: usize = 4096;

/// Returns `true` when `ptr` lies on a 4 KiB boundary and is therefore
/// usable as a direct-I/O buffer address.
#[must_use]
pub fn is_direct_io_aligned<T>(ptr: *const T) -> bool {
    (ptr as usize) % DIRECT_IO_ALIGNMENT == 0
}
77
/// Heap buffer whose start address is 4 KiB-aligned, suitable for
/// direct I/O (`O_DIRECT`) operations.
///
/// Invariant: `ptr` was returned by `std::alloc::alloc_zeroed` with
/// `layout`, and `len` bytes starting at `ptr` are valid for reads and
/// writes for the lifetime of the buffer.
#[cfg(not(target_arch = "wasm32"))]
pub struct AlignedBuffer {
    // Start of the allocation; non-null after a successful `new()`.
    ptr: *mut u8,
    // Buffer length in bytes, as requested by the caller of `new()`.
    len: usize,
    // Layout the allocation was made with; `Drop` needs it to dealloc.
    layout: std::alloc::Layout,
}
85
86#[cfg(not(target_arch = "wasm32"))]
87impl AlignedBuffer {
88    /// Allocate a new aligned buffer.
89    ///
90    /// # Errors
91    /// Returns an error if allocation fails.
92    pub fn new(size: usize) -> Result<Self, TruenoError> {
93        use std::alloc::{alloc_zeroed, Layout};
94
95        let layout = Layout::from_size_align(size, DIRECT_IO_ALIGNMENT)
96            .map_err(|e| TruenoError::InvalidInput(format!("invalid alignment: {e}")))?;
97
98        // SAFETY: layout is valid, pointer was allocated with matching layout
99        let ptr = unsafe { alloc_zeroed(layout) };
100        if ptr.is_null() {
101            return Err(TruenoError::InvalidInput("allocation failed".into()));
102        }
103
104        Ok(Self { ptr, len: size, layout })
105    }
106
107    /// Get the buffer as a slice.
108    pub fn as_slice(&self) -> &[u8] {
109        // SAFETY: ptr is non-null and points to `self.len` bytes allocated in `new()`;
110        // the allocation lives for the lifetime of `self`.
111        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
112    }
113
114    /// Get the buffer as a mutable slice.
115    pub fn as_mut_slice(&mut self) -> &mut [u8] {
116        // SAFETY: ptr is non-null and points to `self.len` bytes allocated in `new()`;
117        // `&mut self` guarantees exclusive access.
118        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
119    }
120
121    /// Get the raw pointer.
122    pub fn as_ptr(&self) -> *const u8 {
123        self.ptr
124    }
125
126    /// Get the mutable raw pointer.
127    pub fn as_mut_ptr(&mut self) -> *mut u8 {
128        self.ptr
129    }
130
131    /// Get the buffer length.
132    pub fn len(&self) -> usize {
133        self.len
134    }
135
136    /// Check if the buffer is empty.
137    pub fn is_empty(&self) -> bool {
138        self.len == 0
139    }
140}
141
#[cfg(not(target_arch = "wasm32"))]
impl Drop for AlignedBuffer {
    /// Release the allocation made in `AlignedBuffer::new()`.
    fn drop(&mut self) {
        // SAFETY: self.ptr was allocated with std::alloc::alloc_zeroed using self.layout
        // in AlignedBuffer::new(); dealloc uses the matching layout.
        unsafe {
            std::alloc::dealloc(self.ptr, self.layout);
        }
    }
}
152
#[cfg(not(target_arch = "wasm32"))]
// SAFETY: AlignedBuffer exclusively owns its heap allocation (the raw
// pointer is created in `new()` and never shared externally), so moving
// the buffer to another thread is sound.
unsafe impl Send for AlignedBuffer {}

#[cfg(not(target_arch = "wasm32"))]
// SAFETY: the `&self` methods only read through the pointer and the type
// has no interior mutability, so shared access from multiple threads is
// sound; mutation requires `&mut self` (exclusive access).
unsafe impl Sync for AlignedBuffer {}
160
161// ----------------------------------------------------------------------------
162// LCP-03: Memory Advice (madvise patterns)
163// ----------------------------------------------------------------------------
164
/// Memory advice for mmap regions.
///
/// Mirrors the subset of `madvise(2)` advice values used by this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryAdvice {
    /// Sequential access (enable readahead) — maps to `MADV_SEQUENTIAL`.
    Sequential,
    /// Random access (disable readahead) — maps to `MADV_RANDOM`.
    Random,
    /// Will need soon (prefetch) — maps to `MADV_WILLNEED`.
    WillNeed,
    /// Don't need (can be paged out) — maps to `MADV_DONTNEED`.
    DontNeed,
}

// Linux madvise constants (from linux/mman.h). Values match the
// asm-generic definitions shared by x86_64 and aarch64:
// MADV_RANDOM=1, MADV_SEQUENTIAL=2, MADV_WILLNEED=3, MADV_DONTNEED=4.
#[cfg(target_os = "linux")]
const MADV_SEQUENTIAL: i32 = 2;
#[cfg(target_os = "linux")]
const MADV_RANDOM: i32 = 1;
#[cfg(target_os = "linux")]
const MADV_WILLNEED: i32 = 3;
#[cfg(target_os = "linux")]
const MADV_DONTNEED: i32 = 4;
187
/// Apply memory advice to a region via a raw `madvise` syscall (Linux only).
///
/// Avoids a libc dependency by issuing the syscall with inline assembly.
///
/// # Safety
/// The pointer must be valid and the length must not exceed the mapped region.
#[cfg(target_os = "linux")]
// SAFETY: Caller ensures pointer is valid and length does not exceed the mapped region
pub unsafe fn madvise_region(
    addr: *mut u8,
    len: usize,
    advice: MemoryAdvice,
) -> std::io::Result<()> {
    unsafe {
        // madvise syscall number is 28 on x86_64
        #[cfg(target_arch = "x86_64")]
        const SYS_MADVISE: i64 = 28;
        // ...and 233 on aarch64 (per the arch syscall tables).
        #[cfg(target_arch = "aarch64")]
        const SYS_MADVISE: i64 = 233;

        // Translate the portable enum into the kernel's MADV_* flag.
        let advice_flag: i32 = match advice {
            MemoryAdvice::Sequential => MADV_SEQUENTIAL,
            MemoryAdvice::Random => MADV_RANDOM,
            MemoryAdvice::WillNeed => MADV_WILLNEED,
            MemoryAdvice::DontNeed => MADV_DONTNEED,
        };

        let ret: i64;
        #[cfg(target_arch = "x86_64")]
        {
            // x86_64 Linux syscall ABI: number in rax, args in rdi/rsi/rdx;
            // the `syscall` instruction clobbers rcx and r11.
            core::arch::asm!(
                "syscall",
                inout("rax") SYS_MADVISE => ret,
                in("rdi") addr as usize,
                in("rsi") len,
                in("rdx") advice_flag as i64,
                out("rcx") _,
                out("r11") _,
                options(nostack)
            );
        }
        #[cfg(target_arch = "aarch64")]
        {
            // aarch64 Linux syscall ABI: number in x8, args in x0..x2;
            // the result comes back in x0.
            core::arch::asm!(
                "svc 0",
                inout("x8") SYS_MADVISE => _,
                inout("x0") addr as usize => ret,
                in("x1") len,
                in("x2") advice_flag as i64,
                options(nostack)
            );
        }

        // Linux returns -errno on failure (0 on success for madvise).
        if ret < 0 {
            return Err(std::io::Error::from_raw_os_error(-ret as i32));
        }

        Ok(())
    }
}
246
/// No-op stand-in for platforms without `madvise` support.
///
/// Always succeeds; memory advice is purely an optimization hint, so
/// silently doing nothing is the correct fallback.
#[cfg(not(target_os = "linux"))]
// SAFETY: nothing is done with the pointer; there are no unsafe operations
pub unsafe fn madvise_region(
    _addr: *mut u8,
    _len: usize,
    _advice: MemoryAdvice,
) -> std::io::Result<()> {
    Ok(())
}
257
258/// Apply dual-level prefetch strategy (WILLNEED + RANDOM).
259///
260/// This is the llama.cpp pattern for model loading:
261/// 1. MADV_WILLNEED: Tell kernel to prefetch the data
262/// 2. MADV_RANDOM: Disable readahead (model access is random)
263///
264/// # Safety
265/// The pointer must be valid and the length must not exceed the mapped region.
266#[cfg(target_os = "linux")]
267// SAFETY: caller ensures preconditions are met for this unsafe function
268pub unsafe fn prefetch_for_inference(addr: *mut u8, len: usize) -> std::io::Result<()> {
269    unsafe {
270        // First: tell kernel we'll need this data
271        madvise_region(addr, len, MemoryAdvice::WillNeed)?;
272        // Second: hint random access pattern (disables readahead waste)
273        madvise_region(addr, len, MemoryAdvice::Random)?;
274        Ok(())
275    }
276}
277
/// No-op stand-in for non-Linux platforms; always succeeds.
///
/// Prefetch advice is an optimization only, so doing nothing is safe.
#[cfg(not(target_os = "linux"))]
// SAFETY: no-op; no memory is touched
pub unsafe fn prefetch_for_inference(_addr: *mut u8, _len: usize) -> std::io::Result<()> {
    Ok(())
}
284
285// ----------------------------------------------------------------------------
286// LCP-11: Prefetch with Locality Hints
287// ----------------------------------------------------------------------------
288
289/// Prefetch locality hints.
290#[derive(Debug, Clone, Copy, PartialEq, Eq)]
291pub enum PrefetchLocality {
292    /// No temporal locality (use once, don't pollute cache)
293    None = 0,
294    /// Low temporal locality (use a few times)
295    Low = 1,
296    /// Moderate temporal locality
297    Moderate = 2,
298    /// High temporal locality (keep in all cache levels)
299    High = 3,
300}
301
302/// Prefetch data into cache.
303///
304/// # Safety
305/// The pointer must be valid for reading.
306#[inline]
307#[cfg(target_arch = "x86_64")]
308// SAFETY: caller ensures preconditions are met for this unsafe function
309pub unsafe fn prefetch_ptr<T>(ptr: *const T, locality: PrefetchLocality) {
310    unsafe {
311        use core::arch::x86_64::*;
312        match locality {
313            PrefetchLocality::None => _mm_prefetch(ptr as *const i8, _MM_HINT_NTA),
314            PrefetchLocality::Low => _mm_prefetch(ptr as *const i8, _MM_HINT_T2),
315            PrefetchLocality::Moderate => _mm_prefetch(ptr as *const i8, _MM_HINT_T1),
316            PrefetchLocality::High => _mm_prefetch(ptr as *const i8, _MM_HINT_T0),
317        }
318    }
319}
320
/// Prefetch data into cache (ARM64).
///
/// The locality hint is ignored: PRFM's locality encoding is limited, so
/// we always use `pldl1keep` (load, L1, keep).
///
/// # Safety
/// The pointer must be valid for reading.
#[inline]
#[cfg(target_arch = "aarch64")]
// SAFETY: caller ensures preconditions are met for this unsafe function
pub unsafe fn prefetch_ptr<T>(ptr: *const T, _locality: PrefetchLocality) {
    // SAFETY: caller guarantees `ptr` is valid for reads. The explicit
    // block matches the edition-2024 `unsafe_op_in_unsafe_fn` style used
    // by the other unsafe fns in this file (the bare `asm!` would not
    // compile there).
    unsafe {
        core::arch::asm!(
            "prfm pldl1keep, [{ptr}]",
            ptr = in(reg) ptr,
            options(nostack, preserves_flags)
        );
    }
}
333
/// Fallback for architectures without a prefetch instruction.
///
/// Intentionally does nothing; prefetching is only ever an optimization.
#[inline]
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
// SAFETY: caller ensures preconditions are met for this unsafe function
pub unsafe fn prefetch_ptr<T>(_ptr: *const T, _locality: PrefetchLocality) {
    // No-op.
}
341
342/// Prefetch a slice of data.
343///
344/// Prefetches each cache line in the slice.
345#[inline]
346pub fn prefetch_slice<T>(slice: &[T], locality: PrefetchLocality) {
347    let ptr = slice.as_ptr() as *const u8;
348    let len = std::mem::size_of_val(slice);
349
350    for offset in (0..len).step_by(CACHE_LINE_SIZE) {
351        // SAFETY: ptr.add(offset) is bounded by the slice length; prefetch is
352        // a hint and does not dereference memory.
353        unsafe {
354            prefetch_ptr(ptr.add(offset), locality);
355        }
356    }
357}
358
// Unit tests. Several are cfg-gated: the AlignedBuffer tests are skipped on
// wasm32, and the madvise/prefetch stub assertions only run off-Linux.
#[cfg(test)]
mod tests {
    use super::*;

    // --- CacheAligned ---

    #[test]
    fn test_cache_aligned_alignment() {
        let aligned: CacheAligned<u64> = CacheAligned::new(42);
        assert_eq!(std::mem::align_of_val(&aligned), 64);
    }

    #[test]
    fn test_cache_aligned_value() {
        let aligned = CacheAligned::new(42u64);
        assert_eq!(*aligned.get(), 42);
    }

    #[test]
    fn test_cache_aligned_get_mut() {
        let mut aligned = CacheAligned::new(42u64);
        *aligned.get_mut() = 100;
        assert_eq!(*aligned.get(), 100);
    }

    #[test]
    fn test_cache_aligned_into_inner() {
        let aligned = CacheAligned::new(42u64);
        assert_eq!(aligned.into_inner(), 42);
    }

    #[test]
    fn test_cache_aligned_default() {
        let aligned: CacheAligned<u64> = CacheAligned::default();
        assert_eq!(*aligned.get(), 0);
    }

    #[test]
    fn test_cache_aligned_clone() {
        let aligned = CacheAligned::new(42u64);
        let cloned = aligned.clone();
        assert_eq!(*cloned.get(), 42);
    }

    #[test]
    fn test_cache_line_size_f32() {
        assert_eq!(CACHE_LINE_SIZE_F32, 16); // 64 / 4 = 16
    }

    // --- Direct I/O alignment ---

    #[test]
    fn test_direct_io_alignment() {
        assert_eq!(DIRECT_IO_ALIGNMENT, 4096);
    }

    #[test]
    fn test_is_direct_io_aligned() {
        let aligned_addr: usize = 4096 * 10;
        let unaligned_addr: usize = 4096 * 10 + 1;

        assert!(is_direct_io_aligned(aligned_addr as *const u8));
        assert!(!is_direct_io_aligned(unaligned_addr as *const u8));
    }

    // --- AlignedBuffer (not available on wasm32) ---

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_creation() {
        let buffer = AlignedBuffer::new(4096).unwrap();
        assert_eq!(buffer.len(), 4096);
        assert!(!buffer.is_empty());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_zeroed() {
        let buffer = AlignedBuffer::new(1024).unwrap();
        let slice = buffer.as_slice();
        assert!(slice.iter().all(|&b| b == 0));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_aligned_buffer_write() {
        let mut buffer = AlignedBuffer::new(1024).unwrap();
        buffer.as_mut_slice()[0] = 42;
        assert_eq!(buffer.as_slice()[0], 42);
    }

    // --- Memory advice / prefetch ---

    #[test]
    fn test_memory_advice_eq() {
        assert_eq!(MemoryAdvice::Sequential, MemoryAdvice::Sequential);
        assert_ne!(MemoryAdvice::Sequential, MemoryAdvice::Random);
    }

    #[test]
    fn test_prefetch_locality_values() {
        assert_eq!(PrefetchLocality::None as u8, 0);
        assert_eq!(PrefetchLocality::Low as u8, 1);
        assert_eq!(PrefetchLocality::Moderate as u8, 2);
        assert_eq!(PrefetchLocality::High as u8, 3);
    }

    #[test]
    fn test_prefetch_slice_empty() {
        let empty: &[f32] = &[];
        prefetch_slice(empty, PrefetchLocality::High);
        // Should not panic
    }

    #[test]
    fn test_prefetch_slice_small() {
        let data = [1.0f32; 8];
        prefetch_slice(&data, PrefetchLocality::High);
        // Should not panic
    }

    #[test]
    fn test_madvise_region_stub() {
        // On non-Linux, this is a no-op. On Linux the result depends on the
        // kernel's view of the (stack) address, so only smoke-test there.
        // SAFETY: preconditions verified by caller
        unsafe {
            let mut data = [0u8; 4096];
            let _result = madvise_region(data.as_mut_ptr(), data.len(), MemoryAdvice::WillNeed);
            #[cfg(not(target_os = "linux"))]
            assert!(_result.is_ok());
        }
    }

    #[test]
    fn test_prefetch_for_inference_stub() {
        // Same cfg split as above: strict assertion only for the no-op stub.
        // SAFETY: preconditions verified by caller
        unsafe {
            let mut data = [0u8; 4096];
            let _result = prefetch_for_inference(data.as_mut_ptr(), data.len());
            #[cfg(not(target_os = "linux"))]
            assert!(_result.is_ok());
        }
    }
}
494}