Skip to main content

ringkernel_core/
memory.rs

1//! GPU and host memory management abstractions.
2//!
3//! This module provides RAII wrappers for GPU memory, pinned host memory,
4//! and memory pools for efficient allocation.
5
6use std::alloc::{alloc, dealloc, Layout};
7use std::marker::PhantomData;
8use std::ptr::NonNull;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11
12use parking_lot::Mutex;
13
14use crate::error::{Result, RingKernelError};
15
/// Trait for GPU buffer operations.
///
/// Implementors wrap a single device allocation and expose byte-level
/// host <-> device transfer. All methods take `&self`, and the trait
/// requires `Send + Sync`, so implementors handle any internal
/// synchronization themselves.
pub trait GpuBuffer: Send + Sync {
    /// Get buffer size in bytes.
    fn size(&self) -> usize;

    /// Get device pointer (as usize for FFI compatibility).
    fn device_ptr(&self) -> usize;

    /// Copy data from host to device.
    ///
    /// NOTE(review): behavior when `data.len() != self.size()` is not
    /// specified here — confirm against concrete backend implementations.
    fn copy_from_host(&self, data: &[u8]) -> Result<()>;

    /// Copy data from device to host.
    fn copy_to_host(&self, data: &mut [u8]) -> Result<()>;
}
30
/// Trait for device memory allocation.
///
/// Abstracts a device-side allocator so different backends can be
/// swapped behind the same object-safe interface.
pub trait DeviceMemory: Send + Sync {
    /// Allocate device memory.
    fn allocate(&self, size: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Allocate device memory with alignment.
    ///
    /// NOTE(review): alignment constraints (power of two? byte units?)
    /// are not specified here — confirm with implementors.
    fn allocate_aligned(&self, size: usize, alignment: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Get total device memory.
    fn total_memory(&self) -> usize;

    /// Get free device memory.
    fn free_memory(&self) -> usize;
}
45
/// Pinned (page-locked) host memory for efficient DMA transfers.
///
/// Pinned memory allows direct DMA transfers between host and device
/// without intermediate copying, significantly improving transfer performance.
///
/// NOTE(review): the current implementation allocates with the global
/// allocator (see `PinnedMemory::new`), so the memory is not actually
/// page-locked yet — the pinned behavior is a planned backend substitution.
pub struct PinnedMemory<T: Copy> {
    /// Pointer to the first element; non-null by construction in `new`.
    ptr: NonNull<T>,
    /// Number of `T` elements allocated (always > 0; `new` rejects 0).
    len: usize,
    /// Layout used for allocation; needed again for deallocation in `Drop`.
    layout: Layout,
    /// Ties the struct to `T` for variance/drop-check purposes.
    _marker: PhantomData<T>,
}
56
57impl<T: Copy> PinnedMemory<T> {
58    /// Allocate pinned memory for `count` elements.
59    ///
60    /// # Safety
61    ///
62    /// The underlying memory is uninitialized. Caller must ensure
63    /// data is initialized before reading.
64    pub fn new(count: usize) -> Result<Self> {
65        if count == 0 {
66            return Err(RingKernelError::InvalidConfig(
67                "Cannot allocate zero-sized buffer".to_string(),
68            ));
69        }
70
71        let layout =
72            Layout::array::<T>(count).map_err(|_| RingKernelError::HostAllocationFailed {
73                size: count * std::mem::size_of::<T>(),
74            })?;
75
76        // In production, this would use platform-specific pinned allocation
77        // (e.g., cuMemAllocHost for CUDA, or mlock for general case)
78        let ptr = unsafe { alloc(layout) };
79
80        if ptr.is_null() {
81            return Err(RingKernelError::HostAllocationFailed {
82                size: layout.size(),
83            });
84        }
85
86        Ok(Self {
87            // SAFETY: ptr is guaranteed non-null by the is_null() check above
88            ptr: NonNull::new(ptr as *mut T).expect("ptr verified non-null above"),
89            len: count,
90            layout,
91            _marker: PhantomData,
92        })
93    }
94
95    /// Create pinned memory from a slice, copying the data.
96    pub fn from_slice(data: &[T]) -> Result<Self> {
97        let mut mem = Self::new(data.len())?;
98        mem.as_mut_slice().copy_from_slice(data);
99        Ok(mem)
100    }
101
102    /// Get slice reference.
103    pub fn as_slice(&self) -> &[T] {
104        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
105    }
106
107    /// Get mutable slice reference.
108    pub fn as_mut_slice(&mut self) -> &mut [T] {
109        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
110    }
111
112    /// Get raw pointer.
113    pub fn as_ptr(&self) -> *const T {
114        self.ptr.as_ptr()
115    }
116
117    /// Get mutable raw pointer.
118    pub fn as_mut_ptr(&mut self) -> *mut T {
119        self.ptr.as_ptr()
120    }
121
122    /// Get number of elements.
123    pub fn len(&self) -> usize {
124        self.len
125    }
126
127    /// Check if empty.
128    pub fn is_empty(&self) -> bool {
129        self.len == 0
130    }
131
132    /// Get size in bytes.
133    pub fn size_bytes(&self) -> usize {
134        self.len * std::mem::size_of::<T>()
135    }
136}
137
impl<T: Copy> Drop for PinnedMemory<T> {
    fn drop(&mut self) {
        // In production, this would use platform-specific deallocation
        // (matching whichever pinned-allocation call `new` ends up using).
        // SAFETY: `ptr` was produced by `alloc` with exactly `self.layout`
        // in `new`, and ownership is unique, so it is freed exactly once.
        unsafe {
            dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
        }
    }
}

// SAFETY: PinnedMemory uniquely owns its allocation, so moving it to
// another thread is sound whenever `T` itself is `Send`.
unsafe impl<T: Copy + Send> Send for PinnedMemory<T> {}
// SAFETY: shared access only exposes `&[T]` (via `as_slice`), which is
// safe to share across threads whenever `T` is `Sync`.
unsafe impl<T: Copy + Sync> Sync for PinnedMemory<T> {}
150
/// Memory pool for efficient allocation/deallocation.
///
/// Memory pools amortize allocation costs by maintaining a free list
/// of pre-allocated buffers. Statistics counters use relaxed atomics,
/// so the reported numbers are approximate under heavy concurrency.
pub struct MemoryPool {
    /// Pool name for debugging.
    name: String,
    /// Buffer size for this pool.
    buffer_size: usize,
    /// Maximum number of buffers to pool.
    max_buffers: usize,
    /// Free list of buffers.
    free_list: Mutex<Vec<Vec<u8>>>,
    /// Statistics: total allocations.
    total_allocations: AtomicUsize,
    /// Statistics: cache hits.
    cache_hits: AtomicUsize,
    /// Statistics: current pool size (mirrors `free_list.len()`).
    pool_size: AtomicUsize,
}
171
172impl MemoryPool {
173    /// Create a new memory pool.
174    pub fn new(name: impl Into<String>, buffer_size: usize, max_buffers: usize) -> Self {
175        Self {
176            name: name.into(),
177            buffer_size,
178            max_buffers,
179            free_list: Mutex::new(Vec::with_capacity(max_buffers)),
180            total_allocations: AtomicUsize::new(0),
181            cache_hits: AtomicUsize::new(0),
182            pool_size: AtomicUsize::new(0),
183        }
184    }
185
186    /// Allocate a buffer from the pool.
187    pub fn allocate(&self) -> PooledBuffer<'_> {
188        self.total_allocations.fetch_add(1, Ordering::Relaxed);
189
190        let buffer = {
191            let mut free = self.free_list.lock();
192            if let Some(buf) = free.pop() {
193                self.cache_hits.fetch_add(1, Ordering::Relaxed);
194                self.pool_size.fetch_sub(1, Ordering::Relaxed);
195                buf
196            } else {
197                vec![0u8; self.buffer_size]
198            }
199        };
200
201        PooledBuffer {
202            buffer: Some(buffer),
203            pool: self,
204        }
205    }
206
207    /// Return a buffer to the pool.
208    fn return_buffer(&self, mut buffer: Vec<u8>) {
209        let mut free = self.free_list.lock();
210        if free.len() < self.max_buffers {
211            buffer.clear();
212            buffer.resize(self.buffer_size, 0);
213            free.push(buffer);
214            self.pool_size.fetch_add(1, Ordering::Relaxed);
215        }
216        // If pool is full, buffer is dropped
217    }
218
219    /// Get pool name.
220    pub fn name(&self) -> &str {
221        &self.name
222    }
223
224    /// Get buffer size.
225    pub fn buffer_size(&self) -> usize {
226        self.buffer_size
227    }
228
229    /// Get current pool size.
230    pub fn current_size(&self) -> usize {
231        self.pool_size.load(Ordering::Relaxed)
232    }
233
234    /// Get cache hit rate.
235    pub fn hit_rate(&self) -> f64 {
236        let total = self.total_allocations.load(Ordering::Relaxed);
237        let hits = self.cache_hits.load(Ordering::Relaxed);
238        if total == 0 {
239            0.0
240        } else {
241            hits as f64 / total as f64
242        }
243    }
244
245    /// Pre-allocate buffers.
246    pub fn preallocate(&self, count: usize) {
247        let count = count.min(self.max_buffers);
248        let mut free = self.free_list.lock();
249        for _ in free.len()..count {
250            free.push(vec![0u8; self.buffer_size]);
251            self.pool_size.fetch_add(1, Ordering::Relaxed);
252        }
253    }
254}
255
/// A buffer from a memory pool.
///
/// When dropped, the buffer is returned to the pool for reuse.
pub struct PooledBuffer<'a> {
    /// Owned bytes; `None` only transiently inside `Drop`, where the
    /// buffer is taken out and handed back to the pool.
    buffer: Option<Vec<u8>>,
    /// Pool whose free list this buffer returns to on drop.
    pool: &'a MemoryPool,
}
263
264impl<'a> PooledBuffer<'a> {
265    /// Get slice reference.
266    pub fn as_slice(&self) -> &[u8] {
267        self.buffer.as_deref().unwrap_or(&[])
268    }
269
270    /// Get mutable slice reference.
271    pub fn as_mut_slice(&mut self) -> &mut [u8] {
272        self.buffer.as_deref_mut().unwrap_or(&mut [])
273    }
274
275    /// Get buffer length.
276    pub fn len(&self) -> usize {
277        self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
278    }
279
280    /// Check if empty.
281    pub fn is_empty(&self) -> bool {
282        self.len() == 0
283    }
284}
285
impl<'a> Drop for PooledBuffer<'a> {
    fn drop(&mut self) {
        // Hand the bytes back to the owning pool. `take` leaves `None`
        // behind, so the same buffer can never be returned twice.
        if let Some(buffer) = self.buffer.take() {
            self.pool.return_buffer(buffer);
        }
    }
}
293
// Deref/DerefMut let a PooledBuffer be used wherever a byte slice is
// expected, without callers spelling out `as_slice`.
impl<'a> std::ops::Deref for PooledBuffer<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<'a> std::ops::DerefMut for PooledBuffer<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}
307
/// Shared memory pool that can be cloned.
pub type SharedMemoryPool = Arc<MemoryPool>;

/// Create a shared memory pool.
///
/// Convenience wrapper around `MemoryPool::new` + `Arc::new` so callers
/// can cheaply clone handles across threads.
pub fn create_pool(
    name: impl Into<String>,
    buffer_size: usize,
    max_buffers: usize,
) -> SharedMemoryPool {
    Arc::new(MemoryPool::new(name, buffer_size, max_buffers))
}
319
320// ============================================================================
321// Size-Stratified Memory Pool
322// ============================================================================
323
/// Size bucket for stratified pooling.
///
/// Provides predefined size classes for efficient multi-size pooling.
/// Allocations are rounded up to the smallest bucket that fits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SizeBucket {
    /// Tiny buffers (256 bytes) - metadata, small messages.
    Tiny,
    /// Small buffers (1 KB) - typical message payloads.
    Small,
    /// Medium buffers (4 KB) - page-sized allocations.
    #[default]
    Medium,
    /// Large buffers (16 KB) - batch operations.
    Large,
    /// Huge buffers (64 KB) - large transfers.
    Huge,
}

impl SizeBucket {
    /// All bucket variants in order from smallest to largest.
    pub const ALL: [SizeBucket; 5] = [
        SizeBucket::Tiny,
        SizeBucket::Small,
        SizeBucket::Medium,
        SizeBucket::Large,
        SizeBucket::Huge,
    ];

    /// Get the size in bytes for this bucket.
    pub fn size(&self) -> usize {
        match self {
            Self::Tiny => 256,
            Self::Small => 1024,
            Self::Medium => 4096,
            Self::Large => 16384,
            Self::Huge => 65536,
        }
    }

    /// Find the smallest bucket that fits the requested size.
    ///
    /// Returns `Huge` for any size larger than 16KB — including requests
    /// that exceed even the Huge bucket's 64KB capacity.
    pub fn for_size(requested: usize) -> Self {
        // Derive the thresholds from `ALL` + `size()` so the size classes
        // are defined in exactly one place and cannot drift out of sync
        // with the values above.
        Self::ALL
            .into_iter()
            .find(|bucket| requested <= bucket.size())
            .unwrap_or(Self::Huge)
    }

    /// Get the next larger bucket, or self if already at largest.
    pub fn upgrade(&self) -> Self {
        match self {
            Self::Tiny => Self::Small,
            Self::Small => Self::Medium,
            Self::Medium => Self::Large,
            Self::Large => Self::Huge,
            Self::Huge => Self::Huge,
        }
    }

    /// Get the next smaller bucket, or self if already at smallest.
    pub fn downgrade(&self) -> Self {
        match self {
            Self::Tiny => Self::Tiny,
            Self::Small => Self::Tiny,
            Self::Medium => Self::Small,
            Self::Large => Self::Medium,
            Self::Huge => Self::Large,
        }
    }
}

impl std::fmt::Display for SizeBucket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Tiny => write!(f, "Tiny(256B)"),
            Self::Small => write!(f, "Small(1KB)"),
            Self::Medium => write!(f, "Medium(4KB)"),
            Self::Large => write!(f, "Large(16KB)"),
            Self::Huge => write!(f, "Huge(64KB)"),
        }
    }
}
415
/// Statistics for a stratified memory pool.
///
/// Maintained by `StratifiedMemoryPool` under its stats mutex; callers
/// obtain an owned snapshot via `StratifiedMemoryPool::stats`.
#[derive(Debug, Clone, Default)]
pub struct StratifiedPoolStats {
    /// Total allocations across all buckets.
    pub total_allocations: usize,
    /// Total cache hits across all buckets.
    pub total_hits: usize,
    /// Allocations per bucket.
    pub allocations_per_bucket: std::collections::HashMap<SizeBucket, usize>,
    /// Hits per bucket.
    pub hits_per_bucket: std::collections::HashMap<SizeBucket, usize>,
}
428
429impl StratifiedPoolStats {
430    /// Calculate overall hit rate.
431    pub fn hit_rate(&self) -> f64 {
432        if self.total_allocations == 0 {
433            0.0
434        } else {
435            self.total_hits as f64 / self.total_allocations as f64
436        }
437    }
438
439    /// Get hit rate for a specific bucket.
440    pub fn bucket_hit_rate(&self, bucket: SizeBucket) -> f64 {
441        let allocs = self
442            .allocations_per_bucket
443            .get(&bucket)
444            .copied()
445            .unwrap_or(0);
446        let hits = self.hits_per_bucket.get(&bucket).copied().unwrap_or(0);
447        if allocs == 0 {
448            0.0
449        } else {
450            hits as f64 / allocs as f64
451        }
452    }
453}
454
/// Multi-size memory pool with automatic bucket selection.
///
/// Instead of having a single buffer size, this pool maintains separate
/// pools for different size classes. Allocations are rounded up to the
/// smallest bucket that fits.
///
/// # Example
///
/// ```ignore
/// use ringkernel_core::memory::{StratifiedMemoryPool, SizeBucket};
///
/// let pool = StratifiedMemoryPool::new("my_pool");
///
/// // Allocate various sizes - each goes to appropriate bucket
/// let tiny_buf = pool.allocate(100);   // Uses Tiny bucket (256B)
/// let medium_buf = pool.allocate(2000); // Uses Medium bucket (4KB)
///
/// // Check statistics
/// let stats = pool.stats();
/// println!("Hit rate: {:.1}%", stats.hit_rate() * 100.0);
/// ```
pub struct StratifiedMemoryPool {
    /// Base name; per-bucket pools are named `"{name}_{bucket}"`.
    name: String,
    /// One inner pool per `SizeBucket` variant (populated in `with_capacity`).
    buckets: std::collections::HashMap<SizeBucket, MemoryPool>,
    /// Capacity cap applied to every bucket's pool.
    max_buffers_per_bucket: usize,
    /// Aggregate statistics, updated on every allocation.
    stats: Mutex<StratifiedPoolStats>,
}
482
impl StratifiedMemoryPool {
    /// Create a new stratified pool with default settings.
    ///
    /// Creates pools for all bucket sizes with 16 buffers per bucket.
    pub fn new(name: impl Into<String>) -> Self {
        Self::with_capacity(name, 16)
    }

    /// Create a pool with specified max buffers per bucket.
    ///
    /// One `MemoryPool` (named `"{name}_{bucket}"`) is created for every
    /// `SizeBucket` variant, so `allocate_bucket` can rely on each bucket
    /// key being present.
    pub fn with_capacity(name: impl Into<String>, max_buffers_per_bucket: usize) -> Self {
        let name = name.into();
        let mut buckets = std::collections::HashMap::new();

        for bucket in SizeBucket::ALL {
            let pool_name = format!("{}_{}", name, bucket);
            buckets.insert(
                bucket,
                MemoryPool::new(pool_name, bucket.size(), max_buffers_per_bucket),
            );
        }

        Self {
            name,
            buckets,
            max_buffers_per_bucket,
            stats: Mutex::new(StratifiedPoolStats::default()),
        }
    }

    /// Allocate a buffer of at least the requested size.
    ///
    /// The buffer may be larger than requested (rounded up to bucket size).
    pub fn allocate(&self, size: usize) -> StratifiedBuffer<'_> {
        let bucket = SizeBucket::for_size(size);
        self.allocate_bucket(bucket)
    }

    /// Allocate from a specific bucket.
    pub fn allocate_bucket(&self, bucket: SizeBucket) -> StratifiedBuffer<'_> {
        let pool = self
            .buckets
            .get(&bucket)
            .expect("all SizeBucket variants are inserted in new()");

        // Track stats before allocation to capture hit
        // NOTE(review): this is a heuristic and racy under contention —
        // another thread may empty the bucket between this check and
        // `pool.allocate()`, so `was_cached` can over-report hits. The
        // inner pool's own hit counters are the authoritative ones.
        let was_cached = pool.current_size() > 0;
        let buffer = pool.allocate();

        // Update stats (kept separately from each MemoryPool's internal
        // counters; both are incremented for the same allocation).
        {
            let mut stats = self.stats.lock();
            stats.total_allocations += 1;
            *stats.allocations_per_bucket.entry(bucket).or_insert(0) += 1;
            if was_cached {
                stats.total_hits += 1;
                *stats.hits_per_bucket.entry(bucket).or_insert(0) += 1;
            }
        }

        StratifiedBuffer {
            inner: buffer,
            bucket,
            pool: self,
        }
    }

    /// Get pool name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Get max buffers per bucket.
    pub fn max_buffers_per_bucket(&self) -> usize {
        self.max_buffers_per_bucket
    }

    /// Get current size of a specific bucket pool.
    pub fn bucket_size(&self, bucket: SizeBucket) -> usize {
        self.buckets
            .get(&bucket)
            .map(|p| p.current_size())
            .unwrap_or(0)
    }

    /// Get total buffers currently pooled across all buckets.
    pub fn total_pooled(&self) -> usize {
        self.buckets.values().map(|p| p.current_size()).sum()
    }

    /// Get statistics snapshot (a clone taken under the stats lock).
    pub fn stats(&self) -> StratifiedPoolStats {
        self.stats.lock().clone()
    }

    /// Pre-allocate buffers for a specific bucket.
    pub fn preallocate(&self, bucket: SizeBucket, count: usize) {
        if let Some(pool) = self.buckets.get(&bucket) {
            pool.preallocate(count);
        }
    }

    /// Pre-allocate buffers for all buckets.
    pub fn preallocate_all(&self, count_per_bucket: usize) {
        for bucket in SizeBucket::ALL {
            self.preallocate(bucket, count_per_bucket);
        }
    }

    /// Shrink all pools to target utilization.
    ///
    /// Removes excess pooled buffers to free memory.
    /// Reaches into each `MemoryPool`'s private free list directly, so
    /// the pool's `pool_size` counter is kept in sync by hand here.
    pub fn shrink_to(&self, target_per_bucket: usize) {
        for pool in self.buckets.values() {
            let mut free_list = pool.free_list.lock();
            while free_list.len() > target_per_bucket {
                free_list.pop();
                pool.pool_size.fetch_sub(1, Ordering::Relaxed);
            }
        }
    }
}
604
/// A buffer from a stratified memory pool.
///
/// Tracks which bucket it came from for proper return. The actual
/// return-to-pool happens through the inner `PooledBuffer`'s `Drop`.
pub struct StratifiedBuffer<'a> {
    /// Underlying pooled buffer; returns itself to its pool on drop.
    inner: PooledBuffer<'a>,
    /// Size class this buffer was allocated from.
    bucket: SizeBucket,
    /// Back-reference to the owning stratified pool (currently unused).
    #[allow(dead_code)]
    pool: &'a StratifiedMemoryPool,
}
614
impl<'a> StratifiedBuffer<'a> {
    /// Get the size bucket this buffer was allocated from.
    pub fn bucket(&self) -> SizeBucket {
        self.bucket
    }

    /// Get the actual capacity (bucket size, may be larger than requested).
    pub fn capacity(&self) -> usize {
        self.bucket.size()
    }

    /// Get slice reference (delegates to the inner pooled buffer).
    pub fn as_slice(&self) -> &[u8] {
        self.inner.as_slice()
    }

    /// Get mutable slice reference.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.inner.as_mut_slice()
    }

    /// Get buffer length.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
646
// Deref/DerefMut let a StratifiedBuffer be used wherever a byte slice
// is expected, mirroring PooledBuffer's ergonomics.
impl<'a> std::ops::Deref for StratifiedBuffer<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<'a> std::ops::DerefMut for StratifiedBuffer<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}
660
/// Shared stratified memory pool.
pub type SharedStratifiedPool = Arc<StratifiedMemoryPool>;

/// Create a shared stratified memory pool.
///
/// Uses the default capacity (16 buffers per bucket).
pub fn create_stratified_pool(name: impl Into<String>) -> SharedStratifiedPool {
    Arc::new(StratifiedMemoryPool::new(name))
}

/// Create a shared stratified memory pool with custom capacity.
pub fn create_stratified_pool_with_capacity(
    name: impl Into<String>,
    max_buffers_per_bucket: usize,
) -> SharedStratifiedPool {
    Arc::new(StratifiedMemoryPool::with_capacity(
        name,
        max_buffers_per_bucket,
    ))
}
679
680// ============================================================================
681// Memory Pressure Reactions
682// ============================================================================
683
684use crate::observability::MemoryPressureLevel;
685
/// Callback type for memory pressure changes.
///
/// Must be `Send + Sync` because handlers may be invoked from
/// monitoring threads.
pub type PressureCallback = Box<dyn Fn(MemoryPressureLevel) + Send + Sync>;

/// Reaction to memory pressure events.
///
/// Pools can be configured to react to memory pressure by shrinking
/// or invoking custom callbacks.
pub enum PressureReaction {
    /// No automatic reaction to pressure.
    None,
    /// Automatically shrink pool to target utilization.
    ///
    /// The `target_utilization` is a fraction (0.0 to 1.0) of the max
    /// pool size to retain when under pressure.
    Shrink {
        /// Target utilization as fraction of max capacity.
        target_utilization: f64,
    },
    /// Invoke a custom callback on pressure change.
    Callback(PressureCallback),
}
707
708impl std::fmt::Debug for PressureReaction {
709    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
710        match self {
711            Self::None => write!(f, "PressureReaction::None"),
712            Self::Shrink { target_utilization } => {
713                write!(
714                    f,
715                    "PressureReaction::Shrink {{ target_utilization: {} }}",
716                    target_utilization
717                )
718            }
719            Self::Callback(_) => write!(f, "PressureReaction::Callback(<fn>)"),
720        }
721    }
722}
723
/// Memory pressure handler for stratified pools.
///
/// Monitors memory pressure levels and triggers configured reactions.
pub struct PressureHandler {
    /// Configured reaction (fixed at construction).
    reaction: PressureReaction,
    /// Current pressure level; updated on every `on_pressure_change`.
    current_level: Mutex<MemoryPressureLevel>,
}
733
impl PressureHandler {
    /// Create a new pressure handler with the specified reaction.
    ///
    /// The starting level is `MemoryPressureLevel::Normal`.
    pub fn new(reaction: PressureReaction) -> Self {
        Self {
            reaction,
            current_level: Mutex::new(MemoryPressureLevel::Normal),
        }
    }

    /// Create a handler with no reaction.
    pub fn no_reaction() -> Self {
        Self::new(PressureReaction::None)
    }

    /// Create a handler that shrinks to target utilization.
    ///
    /// `target_utilization` is clamped into `[0.0, 1.0]`.
    pub fn shrink_to(target_utilization: f64) -> Self {
        Self::new(PressureReaction::Shrink {
            target_utilization: target_utilization.clamp(0.0, 1.0),
        })
    }

    /// Create a handler with a custom callback.
    pub fn with_callback<F>(callback: F) -> Self
    where
        F: Fn(MemoryPressureLevel) + Send + Sync + 'static,
    {
        Self::new(PressureReaction::Callback(Box::new(callback)))
    }

    /// Get the current pressure level.
    pub fn current_level(&self) -> MemoryPressureLevel {
        *self.current_level.lock()
    }

    /// Handle a pressure level change.
    ///
    /// Returns the number of buffers to retain per bucket (if shrinking).
    /// Reactions (including callbacks) fire only when pressure *increases*;
    /// a decrease just records the new level and returns `None`.
    pub fn on_pressure_change(
        &self,
        new_level: MemoryPressureLevel,
        max_per_bucket: usize,
    ) -> Option<usize> {
        // Swap in the new level first — inside a short lock scope — so
        // `current_level()` is up to date even when no reaction fires.
        let old_level = {
            let mut current = self.current_level.lock();
            let old = *current;
            *current = new_level;
            old
        };

        // Only react if pressure increased
        if !Self::is_higher_pressure(new_level, old_level) {
            return None;
        }

        match &self.reaction {
            PressureReaction::None => None,
            PressureReaction::Shrink { target_utilization } => {
                // Calculate target based on pressure level: higher severity
                // shrinks the retained fraction further (down to 0 at
                // OutOfMemory, floored to 1 buffer below).
                let pressure_factor = Self::pressure_severity(new_level);
                let adjusted_target = target_utilization * (1.0 - pressure_factor);
                // f64 -> usize cast truncates toward zero.
                let target_count = ((max_per_bucket as f64) * adjusted_target) as usize;
                Some(target_count.max(1)) // Keep at least 1
            }
            PressureReaction::Callback(callback) => {
                callback(new_level);
                None
            }
        }
    }

    /// Check if new pressure level is higher than old.
    fn is_higher_pressure(new: MemoryPressureLevel, old: MemoryPressureLevel) -> bool {
        Self::pressure_ordinal(new) > Self::pressure_ordinal(old)
    }

    /// Get ordinal value for pressure level comparison
    /// (higher means more severe).
    fn pressure_ordinal(level: MemoryPressureLevel) -> u8 {
        match level {
            MemoryPressureLevel::Normal => 0,
            MemoryPressureLevel::Elevated => 1,
            MemoryPressureLevel::Warning => 2,
            MemoryPressureLevel::Critical => 3,
            MemoryPressureLevel::OutOfMemory => 4,
        }
    }

    /// Get severity factor (0.0 to 1.0) for pressure level,
    /// used to scale the shrink target in `on_pressure_change`.
    fn pressure_severity(level: MemoryPressureLevel) -> f64 {
        match level {
            MemoryPressureLevel::Normal => 0.0,
            MemoryPressureLevel::Elevated => 0.2,
            MemoryPressureLevel::Warning => 0.5,
            MemoryPressureLevel::Critical => 0.8,
            MemoryPressureLevel::OutOfMemory => 1.0,
        }
    }
}
831
impl std::fmt::Debug for PressureHandler {
    // Manual impl: `PressureReaction` holds a boxed closure, so Debug
    // cannot be derived for the containing struct either.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PressureHandler")
            .field("reaction", &self.reaction)
            .field("current_level", &self.current_level())
            .finish()
    }
}
840
/// Extension trait for pressure-aware memory pools.
///
/// NOTE(review): no implementor is visible in this file — confirm where
/// (or whether) this trait is implemented.
pub trait PressureAwarePool {
    /// Handle a memory pressure change event.
    ///
    /// Returns true if the pool took action (e.g., shrunk).
    fn handle_pressure(&self, level: MemoryPressureLevel) -> bool;

    /// Get current pressure level.
    fn pressure_level(&self) -> MemoryPressureLevel;
}
851
/// Alignment utilities.
///
/// All functions require `alignment` to be a non-zero power of two; the
/// mask arithmetic silently returns wrong values otherwise, so this is
/// checked with `debug_assert!` in debug builds.
pub mod align {
    /// Cache line size (64 bytes on most modern CPUs).
    pub const CACHE_LINE_SIZE: usize = 64;

    /// GPU cache line size (128 bytes on many GPUs).
    pub const GPU_CACHE_LINE_SIZE: usize = 128;

    /// Align a value up to the next multiple of alignment.
    ///
    /// `value + alignment - 1` can overflow for values within `alignment`
    /// of `usize::MAX` (panics in debug builds, wraps in release).
    #[inline]
    pub const fn align_up(value: usize, alignment: usize) -> usize {
        debug_assert!(alignment.is_power_of_two());
        let mask = alignment - 1;
        (value + mask) & !mask
    }

    /// Align a value down to the previous multiple of alignment.
    #[inline]
    pub const fn align_down(value: usize, alignment: usize) -> usize {
        debug_assert!(alignment.is_power_of_two());
        let mask = alignment - 1;
        value & !mask
    }

    /// Check if a value is aligned.
    #[inline]
    pub const fn is_aligned(value: usize, alignment: usize) -> bool {
        debug_assert!(alignment.is_power_of_two());
        value & (alignment - 1) == 0
    }

    /// Get required padding to bring `offset` up to the next aligned
    /// boundary (0 when already aligned).
    #[inline]
    pub const fn padding_for(offset: usize, alignment: usize) -> usize {
        debug_assert!(alignment.is_power_of_two());
        let misalignment = offset & (alignment - 1);
        if misalignment == 0 {
            0
        } else {
            alignment - misalignment
        }
    }
}
891
892#[cfg(test)]
893mod tests {
894    use super::*;
895
    #[test]
    fn test_pinned_memory() {
        let mut mem = PinnedMemory::<f32>::new(1024).unwrap();
        assert_eq!(mem.len(), 1024);
        // f32 is 4 bytes, so 1024 elements -> 4096 bytes.
        assert_eq!(mem.size_bytes(), 1024 * 4);

        // Write some data
        let slice = mem.as_mut_slice();
        for (i, v) in slice.iter_mut().enumerate() {
            *v = i as f32;
        }

        // Read back
        assert_eq!(mem.as_slice()[42], 42.0);
    }

    #[test]
    fn test_pinned_memory_from_slice() {
        // from_slice copies the source data verbatim.
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mem = PinnedMemory::from_slice(&data).unwrap();
        assert_eq!(mem.as_slice(), &data[..]);
    }
918
    #[test]
    fn test_memory_pool() {
        let pool = MemoryPool::new("test", 1024, 10);

        // First allocation should be fresh
        let buf1 = pool.allocate();
        assert_eq!(buf1.len(), 1024);
        // Dropping returns the buffer to the pool's free list.
        drop(buf1);

        // Second allocation should be cached
        let _buf2 = pool.allocate();
        assert_eq!(pool.hit_rate(), 0.5); // 1 hit out of 2 allocations
    }

    #[test]
    fn test_pool_preallocate() {
        let pool = MemoryPool::new("test", 1024, 10);
        pool.preallocate(5);
        assert_eq!(pool.current_size(), 5);

        // All allocations should hit cache
        // (each temporary drops immediately, refilling the free list).
        for _ in 0..5 {
            let _ = pool.allocate();
        }
        assert_eq!(pool.hit_rate(), 1.0);
    }
945
    #[test]
    fn test_align_up() {
        use align::*;

        // Rounds up to the next multiple of 64; exact multiples unchanged.
        assert_eq!(align_up(0, 64), 0);
        assert_eq!(align_up(1, 64), 64);
        assert_eq!(align_up(64, 64), 64);
        assert_eq!(align_up(65, 64), 128);
    }

    #[test]
    fn test_is_aligned() {
        use align::*;

        assert!(is_aligned(0, 64));
        assert!(is_aligned(64, 64));
        assert!(is_aligned(128, 64));
        assert!(!is_aligned(1, 64));
        assert!(!is_aligned(63, 64));
    }

    #[test]
    fn test_padding_for() {
        use align::*;

        // Padding needed to reach the next 64-byte boundary.
        assert_eq!(padding_for(0, 64), 0);
        assert_eq!(padding_for(1, 64), 63);
        assert_eq!(padding_for(63, 64), 1);
        assert_eq!(padding_for(64, 64), 0);
    }
976
977    // ========================================================================
978    // Size-Stratified Pool Tests
979    // ========================================================================
980
    #[test]
    fn test_size_bucket_sizes() {
        // Size classes double in steps of 4x from 256B to 64KB.
        assert_eq!(SizeBucket::Tiny.size(), 256);
        assert_eq!(SizeBucket::Small.size(), 1024);
        assert_eq!(SizeBucket::Medium.size(), 4096);
        assert_eq!(SizeBucket::Large.size(), 16384);
        assert_eq!(SizeBucket::Huge.size(), 65536);
    }

    #[test]
    fn test_size_bucket_selection() {
        // Exact boundaries: a request equal to the bucket size fits that
        // bucket; one byte more rolls over to the next class.
        assert_eq!(SizeBucket::for_size(0), SizeBucket::Tiny);
        assert_eq!(SizeBucket::for_size(256), SizeBucket::Tiny);
        assert_eq!(SizeBucket::for_size(257), SizeBucket::Small);
        assert_eq!(SizeBucket::for_size(1024), SizeBucket::Small);
        assert_eq!(SizeBucket::for_size(1025), SizeBucket::Medium);
        assert_eq!(SizeBucket::for_size(4096), SizeBucket::Medium);
        assert_eq!(SizeBucket::for_size(4097), SizeBucket::Large);
        assert_eq!(SizeBucket::for_size(16384), SizeBucket::Large);
        assert_eq!(SizeBucket::for_size(16385), SizeBucket::Huge);
        assert_eq!(SizeBucket::for_size(100000), SizeBucket::Huge);
    }

    #[test]
    fn test_size_bucket_upgrade_downgrade() {
        assert_eq!(SizeBucket::Tiny.upgrade(), SizeBucket::Small);
        assert_eq!(SizeBucket::Small.upgrade(), SizeBucket::Medium);
        assert_eq!(SizeBucket::Medium.upgrade(), SizeBucket::Large);
        assert_eq!(SizeBucket::Large.upgrade(), SizeBucket::Huge);
        assert_eq!(SizeBucket::Huge.upgrade(), SizeBucket::Huge); // Max

        assert_eq!(SizeBucket::Tiny.downgrade(), SizeBucket::Tiny); // Min
        assert_eq!(SizeBucket::Small.downgrade(), SizeBucket::Tiny);
        assert_eq!(SizeBucket::Medium.downgrade(), SizeBucket::Small);
        assert_eq!(SizeBucket::Large.downgrade(), SizeBucket::Medium);
        assert_eq!(SizeBucket::Huge.downgrade(), SizeBucket::Large);
    }
1019
1020    #[test]
1021    fn test_stratified_pool_allocation() {
1022        let pool = StratifiedMemoryPool::new("test");
1023
1024        // Allocate different sizes
1025        let buf1 = pool.allocate(100); // Tiny
1026        let buf2 = pool.allocate(500); // Small
1027        let buf3 = pool.allocate(2000); // Medium
1028
1029        assert_eq!(buf1.bucket(), SizeBucket::Tiny);
1030        assert_eq!(buf2.bucket(), SizeBucket::Small);
1031        assert_eq!(buf3.bucket(), SizeBucket::Medium);
1032
1033        // Buffers have full bucket capacity
1034        assert_eq!(buf1.capacity(), 256);
1035        assert_eq!(buf2.capacity(), 1024);
1036        assert_eq!(buf3.capacity(), 4096);
1037    }
1038
1039    #[test]
1040    fn test_stratified_pool_reuse() {
1041        let pool = StratifiedMemoryPool::new("test");
1042
1043        // First allocation - fresh
1044        {
1045            let _buf = pool.allocate(100);
1046        }
1047        // Buffer returned to pool
1048
1049        // Second allocation - should reuse
1050        {
1051            let _buf = pool.allocate(100);
1052        }
1053
1054        let stats = pool.stats();
1055        assert_eq!(stats.total_allocations, 2);
1056        assert_eq!(stats.total_hits, 1);
1057        assert!((stats.hit_rate() - 0.5).abs() < 0.001);
1058    }
1059
1060    #[test]
1061    fn test_stratified_pool_stats_per_bucket() {
1062        let pool = StratifiedMemoryPool::new("test");
1063
1064        // Allocate from different buckets
1065        let _buf1 = pool.allocate(100); // Tiny
1066        let _buf2 = pool.allocate(500); // Small
1067        let _buf3 = pool.allocate(100); // Tiny again
1068
1069        let stats = pool.stats();
1070        assert_eq!(stats.total_allocations, 3);
1071        assert_eq!(
1072            stats.allocations_per_bucket.get(&SizeBucket::Tiny),
1073            Some(&2)
1074        );
1075        assert_eq!(
1076            stats.allocations_per_bucket.get(&SizeBucket::Small),
1077            Some(&1)
1078        );
1079    }
1080
1081    #[test]
1082    fn test_stratified_pool_preallocate() {
1083        let pool = StratifiedMemoryPool::new("test");
1084
1085        pool.preallocate(SizeBucket::Medium, 5);
1086        assert_eq!(pool.bucket_size(SizeBucket::Medium), 5);
1087        assert_eq!(pool.bucket_size(SizeBucket::Tiny), 0);
1088
1089        // All medium allocations should hit cache
1090        for _ in 0..5 {
1091            let _buf = pool.allocate(2000);
1092        }
1093
1094        let stats = pool.stats();
1095        assert_eq!(stats.hits_per_bucket.get(&SizeBucket::Medium), Some(&5));
1096    }
1097
1098    #[test]
1099    fn test_stratified_pool_shrink() {
1100        let pool = StratifiedMemoryPool::new("test");
1101
1102        // Preallocate then shrink
1103        pool.preallocate_all(10);
1104        assert_eq!(pool.total_pooled(), 50); // 5 buckets * 10
1105
1106        pool.shrink_to(2);
1107        assert_eq!(pool.total_pooled(), 10); // 5 buckets * 2
1108    }
1109
1110    #[test]
1111    fn test_stratified_buffer_deref() {
1112        let pool = StratifiedMemoryPool::new("test");
1113
1114        let mut buf = pool.allocate(100);
1115
1116        // Write via DerefMut
1117        buf[0] = 42;
1118        buf[1] = 43;
1119
1120        // Read via Deref
1121        assert_eq!(buf[0], 42);
1122        assert_eq!(buf[1], 43);
1123    }
1124
1125    // ========================================================================
1126    // Memory Pressure Reaction Tests
1127    // ========================================================================
1128
1129    #[test]
1130    fn test_pressure_handler_no_reaction() {
1131        let handler = PressureHandler::no_reaction();
1132        assert_eq!(handler.current_level(), MemoryPressureLevel::Normal);
1133
1134        let result = handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
1135        assert!(result.is_none());
1136    }
1137
1138    #[test]
1139    fn test_pressure_handler_shrink() {
1140        let handler = PressureHandler::shrink_to(0.5);
1141
1142        // Normal -> Critical should trigger shrink
1143        let result = handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
1144        assert!(result.is_some());
1145        // With 0.5 target and 0.8 severity, adjusted = 0.5 * (1.0 - 0.8) = 0.1
1146        // 10 * 0.1 = 1 -> max(1, 1) = 1
1147        assert!(result.unwrap() >= 1);
1148    }
1149
1150    #[test]
1151    fn test_pressure_handler_callback() {
1152        use std::sync::atomic::{AtomicBool, Ordering};
1153        use std::sync::Arc;
1154
1155        let called = Arc::new(AtomicBool::new(false));
1156        let called_clone = called.clone();
1157
1158        let handler = PressureHandler::with_callback(move |level| {
1159            if level == MemoryPressureLevel::Critical {
1160                called_clone.store(true, Ordering::SeqCst);
1161            }
1162        });
1163
1164        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
1165        assert!(called.load(Ordering::SeqCst));
1166    }
1167
1168    #[test]
1169    fn test_pressure_handler_only_reacts_to_increase() {
1170        let handler = PressureHandler::shrink_to(0.5);
1171
1172        // Start at Critical
1173        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
1174
1175        // Going back to Normal should not trigger
1176        let result = handler.on_pressure_change(MemoryPressureLevel::Normal, 10);
1177        assert!(result.is_none());
1178    }
1179
1180    #[test]
1181    fn test_pressure_handler_level_tracking() {
1182        let handler = PressureHandler::no_reaction();
1183
1184        assert_eq!(handler.current_level(), MemoryPressureLevel::Normal);
1185
1186        handler.on_pressure_change(MemoryPressureLevel::Warning, 10);
1187        assert_eq!(handler.current_level(), MemoryPressureLevel::Warning);
1188
1189        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
1190        assert_eq!(handler.current_level(), MemoryPressureLevel::Critical);
1191    }
1192
1193    #[test]
1194    fn test_pressure_reaction_debug() {
1195        let none = PressureReaction::None;
1196        assert!(format!("{:?}", none).contains("None"));
1197
1198        let shrink = PressureReaction::Shrink {
1199            target_utilization: 0.5,
1200        };
1201        assert!(format!("{:?}", shrink).contains("0.5"));
1202
1203        let callback = PressureReaction::Callback(Box::new(|_| {}));
1204        assert!(format!("{:?}", callback).contains("Callback"));
1205    }
1206
1207    #[test]
1208    fn test_pressure_handler_debug() {
1209        let handler = PressureHandler::shrink_to(0.3);
1210        let debug_str = format!("{:?}", handler);
1211        assert!(debug_str.contains("PressureHandler"));
1212        assert!(debug_str.contains("Shrink"));
1213    }
1214
1215    #[test]
1216    fn test_pressure_severity_values() {
1217        // Test that severity increases with pressure level
1218        let normal = PressureHandler::pressure_severity(MemoryPressureLevel::Normal);
1219        let elevated = PressureHandler::pressure_severity(MemoryPressureLevel::Elevated);
1220        let warning = PressureHandler::pressure_severity(MemoryPressureLevel::Warning);
1221        let critical = PressureHandler::pressure_severity(MemoryPressureLevel::Critical);
1222        let oom = PressureHandler::pressure_severity(MemoryPressureLevel::OutOfMemory);
1223
1224        assert!(normal < elevated);
1225        assert!(elevated < warning);
1226        assert!(warning < critical);
1227        assert!(critical < oom);
1228        assert!(oom <= 1.0);
1229    }
1230}