ringkernel_core/memory.rs

//! GPU and host memory management abstractions.
//!
//! This module provides RAII wrappers for GPU memory, pinned host memory,
//! and memory pools for efficient allocation.

use std::alloc::{alloc, dealloc, Layout};
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use parking_lot::Mutex;

use crate::error::{Result, RingKernelError};

/// Trait for GPU buffer operations.
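///
/// # Example
///
/// A minimal host-backed sketch, useful for tests. The `HostBuffer` type is
/// hypothetical, not part of this crate; a real backend would wrap a device
/// allocation instead of a `Vec`.
///
/// ```ignore
/// use parking_lot::Mutex;
/// use ringkernel_core::error::Result;
/// use ringkernel_core::memory::GpuBuffer;
///
/// struct HostBuffer {
///     // Interior mutability, because the copy methods take `&self`.
///     data: Mutex<Vec<u8>>,
/// }
///
/// impl GpuBuffer for HostBuffer {
///     fn size(&self) -> usize {
///         self.data.lock().len()
///     }
///
///     fn device_ptr(&self) -> usize {
///         self.data.lock().as_ptr() as usize
///     }
///
///     fn copy_from_host(&self, data: &[u8]) -> Result<()> {
///         // Panics if `data` is larger than the buffer; fine for a sketch.
///         self.data.lock()[..data.len()].copy_from_slice(data);
///         Ok(())
///     }
///
///     fn copy_to_host(&self, data: &mut [u8]) -> Result<()> {
///         data.copy_from_slice(&self.data.lock()[..data.len()]);
///         Ok(())
///     }
/// }
/// ```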
pub trait GpuBuffer: Send + Sync {
    /// Get buffer size in bytes.
    fn size(&self) -> usize;

    /// Get device pointer (as usize for FFI compatibility).
    fn device_ptr(&self) -> usize;

    /// Copy data from host to device.
    fn copy_from_host(&self, data: &[u8]) -> Result<()>;

    /// Copy data from device to host.
    fn copy_to_host(&self, data: &mut [u8]) -> Result<()>;
}

/// Trait for device memory allocation.
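///
/// # Example
///
/// A sketch of a helper built on top of a backend-provided allocator; the
/// `upload` function is illustrative, not part of this crate.
///
/// ```ignore
/// use ringkernel_core::error::Result;
/// use ringkernel_core::memory::{DeviceMemory, GpuBuffer};
///
/// fn upload(device: &dyn DeviceMemory, bytes: &[u8]) -> Result<Box<dyn GpuBuffer>> {
///     // Allocate device memory and stage the host bytes into it.
///     let buf = device.allocate(bytes.len())?;
///     buf.copy_from_host(bytes)?;
///     Ok(buf)
/// }
/// ```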
pub trait DeviceMemory: Send + Sync {
    /// Allocate device memory.
    fn allocate(&self, size: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Allocate device memory with alignment.
    fn allocate_aligned(&self, size: usize, alignment: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Get total device memory.
    fn total_memory(&self) -> usize;

    /// Get free device memory.
    fn free_memory(&self) -> usize;
}

/// Pinned (page-locked) host memory for efficient DMA transfers.
///
/// Pinned memory allows direct DMA transfers between host and device
/// without intermediate copying, significantly improving transfer
/// performance. Note that this implementation currently falls back to the
/// standard allocator; platform-specific pinning is left to the backend.
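///
/// # Example
///
/// A minimal usage sketch:
///
/// ```ignore
/// use ringkernel_core::memory::PinnedMemory;
///
/// // Stage host data in a pinned buffer before a DMA transfer.
/// let staging = PinnedMemory::from_slice(&[1.0f32, 2.0, 3.0]).unwrap();
/// assert_eq!(staging.len(), 3);
/// assert_eq!(staging.size_bytes(), 12);
///
/// // A transfer API would consume the raw pointer.
/// let _ptr = staging.as_ptr();
/// ```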
pub struct PinnedMemory<T: Copy> {
    ptr: NonNull<T>,
    len: usize,
    layout: Layout,
    _marker: PhantomData<T>,
}

impl<T: Copy> PinnedMemory<T> {
    /// Allocate pinned memory for `count` elements.
    ///
    /// # Safety
    ///
    /// The underlying memory is uninitialized. Caller must ensure
    /// data is initialized before reading.
    pub fn new(count: usize) -> Result<Self> {
        if count == 0 {
            return Err(RingKernelError::InvalidConfig(
                "Cannot allocate zero-sized buffer".to_string(),
            ));
        }

        let layout =
            Layout::array::<T>(count).map_err(|_| RingKernelError::HostAllocationFailed {
                // `Layout::array` fails on overflow, so don't overflow again here.
                size: count.saturating_mul(std::mem::size_of::<T>()),
            })?;

        // In production, this would use platform-specific pinned allocation
        // (e.g., cuMemAllocHost for CUDA, or mlock for general case)
        let ptr = unsafe { alloc(layout) };

        if ptr.is_null() {
            return Err(RingKernelError::HostAllocationFailed {
                size: layout.size(),
            });
        }

        Ok(Self {
            ptr: NonNull::new(ptr as *mut T).unwrap(),
            len: count,
            layout,
            _marker: PhantomData,
        })
    }

    /// Create pinned memory from a slice, copying the data.
    pub fn from_slice(data: &[T]) -> Result<Self> {
        let mut mem = Self::new(data.len())?;
        // Copy through raw pointers: the fresh allocation is uninitialized,
        // so avoid forming a slice over it before it has been written.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), mem.as_mut_ptr(), data.len());
        }
        Ok(mem)
    }

    /// Get slice reference.
    pub fn as_slice(&self) -> &[T] {
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }

    /// Get mutable slice reference.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }

    /// Get raw pointer.
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr()
    }

    /// Get mutable raw pointer.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Get number of elements.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Get size in bytes.
    pub fn size_bytes(&self) -> usize {
        self.len * std::mem::size_of::<T>()
    }
}

impl<T: Copy> Drop for PinnedMemory<T> {
    fn drop(&mut self) {
        // In production, this would use platform-specific deallocation
        unsafe {
            dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
        }
    }
}

// SAFETY: PinnedMemory can be sent between threads
unsafe impl<T: Copy + Send> Send for PinnedMemory<T> {}
unsafe impl<T: Copy + Sync> Sync for PinnedMemory<T> {}

/// Memory pool for efficient allocation/deallocation.
///
/// Memory pools amortize allocation costs by maintaining a free list
/// of pre-allocated buffers.
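///
/// # Example
///
/// A small usage sketch (pool name and sizes are illustrative):
///
/// ```ignore
/// use ringkernel_core::memory::MemoryPool;
///
/// let pool = MemoryPool::new("rx_frames", 4096, 32);
/// pool.preallocate(8);
///
/// // Buffers are checked out RAII-style and returned on drop.
/// {
///     let mut buf = pool.allocate();
///     buf[0] = 0xAB;
/// } // buffer goes back to the free list here
///
/// assert_eq!(pool.current_size(), 8);
/// assert_eq!(pool.hit_rate(), 1.0);
/// ```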
pub struct MemoryPool {
    /// Pool name for debugging.
    name: String,
    /// Buffer size for this pool.
    buffer_size: usize,
    /// Maximum number of buffers to pool.
    max_buffers: usize,
    /// Free list of buffers.
    free_list: Mutex<Vec<Vec<u8>>>,
    /// Statistics: total allocations.
    total_allocations: AtomicUsize,
    /// Statistics: cache hits.
    cache_hits: AtomicUsize,
    /// Statistics: current pool size.
    pool_size: AtomicUsize,
}

impl MemoryPool {
    /// Create a new memory pool.
    pub fn new(name: impl Into<String>, buffer_size: usize, max_buffers: usize) -> Self {
        Self {
            name: name.into(),
            buffer_size,
            max_buffers,
            free_list: Mutex::new(Vec::with_capacity(max_buffers)),
            total_allocations: AtomicUsize::new(0),
            cache_hits: AtomicUsize::new(0),
            pool_size: AtomicUsize::new(0),
        }
    }

    /// Allocate a buffer from the pool.
    pub fn allocate(&self) -> PooledBuffer<'_> {
        self.total_allocations.fetch_add(1, Ordering::Relaxed);

        let buffer = {
            let mut free = self.free_list.lock();
            if let Some(buf) = free.pop() {
                self.cache_hits.fetch_add(1, Ordering::Relaxed);
                self.pool_size.fetch_sub(1, Ordering::Relaxed);
                buf
            } else {
                vec![0u8; self.buffer_size]
            }
        };

        PooledBuffer {
            buffer: Some(buffer),
            pool: self,
        }
    }

    /// Return a buffer to the pool.
    fn return_buffer(&self, mut buffer: Vec<u8>) {
        let mut free = self.free_list.lock();
        if free.len() < self.max_buffers {
            buffer.clear();
            buffer.resize(self.buffer_size, 0);
            free.push(buffer);
            self.pool_size.fetch_add(1, Ordering::Relaxed);
        }
        // If pool is full, buffer is dropped
    }

    /// Get pool name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Get buffer size.
    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Get current pool size.
    pub fn current_size(&self) -> usize {
        self.pool_size.load(Ordering::Relaxed)
    }

    /// Get cache hit rate.
    pub fn hit_rate(&self) -> f64 {
        let total = self.total_allocations.load(Ordering::Relaxed);
        let hits = self.cache_hits.load(Ordering::Relaxed);
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Pre-allocate buffers.
    pub fn preallocate(&self, count: usize) {
        let count = count.min(self.max_buffers);
        let mut free = self.free_list.lock();
        for _ in free.len()..count {
            free.push(vec![0u8; self.buffer_size]);
            self.pool_size.fetch_add(1, Ordering::Relaxed);
        }
    }
}

/// A buffer from a memory pool.
///
/// When dropped, the buffer is returned to the pool for reuse.
pub struct PooledBuffer<'a> {
    buffer: Option<Vec<u8>>,
    pool: &'a MemoryPool,
}

impl<'a> PooledBuffer<'a> {
    /// Get slice reference.
    pub fn as_slice(&self) -> &[u8] {
        self.buffer.as_deref().unwrap_or(&[])
    }

    /// Get mutable slice reference.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.buffer.as_deref_mut().unwrap_or(&mut [])
    }

    /// Get buffer length.
    pub fn len(&self) -> usize {
        self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<'a> Drop for PooledBuffer<'a> {
    fn drop(&mut self) {
        if let Some(buffer) = self.buffer.take() {
            self.pool.return_buffer(buffer);
        }
    }
}

impl<'a> std::ops::Deref for PooledBuffer<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<'a> std::ops::DerefMut for PooledBuffer<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}

/// Shared memory pool that can be cloned.
pub type SharedMemoryPool = Arc<MemoryPool>;

/// Create a shared memory pool.
pub fn create_pool(
    name: impl Into<String>,
    buffer_size: usize,
    max_buffers: usize,
) -> SharedMemoryPool {
    Arc::new(MemoryPool::new(name, buffer_size, max_buffers))
}

// ============================================================================
// Size-Stratified Memory Pool
// ============================================================================

/// Size bucket for stratified pooling.
///
/// Provides predefined size classes for efficient multi-size pooling.
/// Allocations are rounded up to the smallest bucket that fits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SizeBucket {
    /// Tiny buffers (256 bytes) - metadata, small messages.
    Tiny,
    /// Small buffers (1 KB) - typical message payloads.
    Small,
    /// Medium buffers (4 KB) - page-sized allocations.
    #[default]
    Medium,
    /// Large buffers (16 KB) - batch operations.
    Large,
    /// Huge buffers (64 KB) - large transfers.
    Huge,
}

impl SizeBucket {
    /// All bucket variants in order from smallest to largest.
    pub const ALL: [SizeBucket; 5] = [
        SizeBucket::Tiny,
        SizeBucket::Small,
        SizeBucket::Medium,
        SizeBucket::Large,
        SizeBucket::Huge,
    ];

    /// Get the size in bytes for this bucket.
    pub fn size(&self) -> usize {
        match self {
            Self::Tiny => 256,
            Self::Small => 1024,
            Self::Medium => 4096,
            Self::Large => 16384,
            Self::Huge => 65536,
        }
    }

    /// Find the smallest bucket that fits the requested size.
    ///
    /// Returns `Huge` for any size larger than 16 KB. Note that requests
    /// above 64 KB also map to `Huge`, whose buffers are only 64 KB, so
    /// callers must not rely on oversized requests being satisfied in full.
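    ///
    /// A quick illustration of the mapping:
    ///
    /// ```ignore
    /// use ringkernel_core::memory::SizeBucket;
    ///
    /// assert_eq!(SizeBucket::for_size(200), SizeBucket::Tiny);
    /// assert_eq!(SizeBucket::for_size(4096), SizeBucket::Medium);
    /// // Oversized requests are clamped to Huge (64 KB).
    /// assert_eq!(SizeBucket::for_size(1 << 20), SizeBucket::Huge);
    /// ```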
    pub fn for_size(requested: usize) -> Self {
        if requested <= 256 {
            Self::Tiny
        } else if requested <= 1024 {
            Self::Small
        } else if requested <= 4096 {
            Self::Medium
        } else if requested <= 16384 {
            Self::Large
        } else {
            Self::Huge
        }
    }

    /// Get the next larger bucket, or self if already at largest.
    pub fn upgrade(&self) -> Self {
        match self {
            Self::Tiny => Self::Small,
            Self::Small => Self::Medium,
            Self::Medium => Self::Large,
            Self::Large => Self::Huge,
            Self::Huge => Self::Huge,
        }
    }

    /// Get the next smaller bucket, or self if already at smallest.
    pub fn downgrade(&self) -> Self {
        match self {
            Self::Tiny => Self::Tiny,
            Self::Small => Self::Tiny,
            Self::Medium => Self::Small,
            Self::Large => Self::Medium,
            Self::Huge => Self::Large,
        }
    }
}

impl std::fmt::Display for SizeBucket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Tiny => write!(f, "Tiny(256B)"),
            Self::Small => write!(f, "Small(1KB)"),
            Self::Medium => write!(f, "Medium(4KB)"),
            Self::Large => write!(f, "Large(16KB)"),
            Self::Huge => write!(f, "Huge(64KB)"),
        }
    }
}

/// Statistics for a stratified memory pool.
#[derive(Debug, Clone, Default)]
pub struct StratifiedPoolStats {
    /// Total allocations across all buckets.
    pub total_allocations: usize,
    /// Total cache hits across all buckets.
    pub total_hits: usize,
    /// Allocations per bucket.
    pub allocations_per_bucket: std::collections::HashMap<SizeBucket, usize>,
    /// Hits per bucket.
    pub hits_per_bucket: std::collections::HashMap<SizeBucket, usize>,
}

impl StratifiedPoolStats {
    /// Calculate overall hit rate.
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.total_hits as f64 / self.total_allocations as f64
        }
    }

    /// Get hit rate for a specific bucket.
    pub fn bucket_hit_rate(&self, bucket: SizeBucket) -> f64 {
        let allocs = self
            .allocations_per_bucket
            .get(&bucket)
            .copied()
            .unwrap_or(0);
        let hits = self.hits_per_bucket.get(&bucket).copied().unwrap_or(0);
        if allocs == 0 {
            0.0
        } else {
            hits as f64 / allocs as f64
        }
    }
}

/// Multi-size memory pool with automatic bucket selection.
///
/// Instead of having a single buffer size, this pool maintains separate
/// pools for different size classes. Allocations are rounded up to the
/// smallest bucket that fits.
///
/// # Example
///
/// ```ignore
/// use ringkernel_core::memory::{StratifiedMemoryPool, SizeBucket};
///
/// let pool = StratifiedMemoryPool::new("my_pool");
///
/// // Allocate various sizes - each goes to the appropriate bucket
/// let tiny_buf = pool.allocate(100);   // Uses Tiny bucket (256B)
/// let medium_buf = pool.allocate(2000); // Uses Medium bucket (4KB)
///
/// // Check statistics
/// let stats = pool.stats();
/// println!("Hit rate: {:.1}%", stats.hit_rate() * 100.0);
/// ```
pub struct StratifiedMemoryPool {
    name: String,
    buckets: std::collections::HashMap<SizeBucket, MemoryPool>,
    max_buffers_per_bucket: usize,
    stats: Mutex<StratifiedPoolStats>,
}

impl StratifiedMemoryPool {
    /// Create a new stratified pool with default settings.
    ///
    /// Creates pools for all bucket sizes with 16 buffers per bucket.
    pub fn new(name: impl Into<String>) -> Self {
        Self::with_capacity(name, 16)
    }

    /// Create a pool with specified max buffers per bucket.
    pub fn with_capacity(name: impl Into<String>, max_buffers_per_bucket: usize) -> Self {
        let name = name.into();
        let mut buckets = std::collections::HashMap::new();

        for bucket in SizeBucket::ALL {
            let pool_name = format!("{}_{}", name, bucket);
            buckets.insert(
                bucket,
                MemoryPool::new(pool_name, bucket.size(), max_buffers_per_bucket),
            );
        }

        Self {
            name,
            buckets,
            max_buffers_per_bucket,
            stats: Mutex::new(StratifiedPoolStats::default()),
        }
    }

    /// Allocate a buffer of at least the requested size.
    ///
    /// The buffer may be larger than requested (rounded up to the bucket
    /// size). Note that requests above 64 KB are clamped to the `Huge`
    /// bucket and receive a buffer smaller than requested; see
    /// [`SizeBucket::for_size`].
    pub fn allocate(&self, size: usize) -> StratifiedBuffer<'_> {
        let bucket = SizeBucket::for_size(size);
        self.allocate_bucket(bucket)
    }

    /// Allocate from a specific bucket.
    pub fn allocate_bucket(&self, bucket: SizeBucket) -> StratifiedBuffer<'_> {
        let pool = self.buckets.get(&bucket).expect("bucket pool exists");

        // Track stats before allocation to capture a hit. This peek is racy
        // under concurrent allocation, so hit statistics are best-effort.
        let was_cached = pool.current_size() > 0;
        let buffer = pool.allocate();

        // Update stats
        {
            let mut stats = self.stats.lock();
            stats.total_allocations += 1;
            *stats.allocations_per_bucket.entry(bucket).or_insert(0) += 1;
            if was_cached {
                stats.total_hits += 1;
                *stats.hits_per_bucket.entry(bucket).or_insert(0) += 1;
            }
        }

        StratifiedBuffer {
            inner: buffer,
            bucket,
            pool: self,
        }
    }

    /// Get pool name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Get max buffers per bucket.
    pub fn max_buffers_per_bucket(&self) -> usize {
        self.max_buffers_per_bucket
    }

    /// Get current size of a specific bucket pool.
    pub fn bucket_size(&self, bucket: SizeBucket) -> usize {
        self.buckets
            .get(&bucket)
            .map(|p| p.current_size())
            .unwrap_or(0)
    }

    /// Get total buffers currently pooled across all buckets.
    pub fn total_pooled(&self) -> usize {
        self.buckets.values().map(|p| p.current_size()).sum()
    }

    /// Get statistics snapshot.
    pub fn stats(&self) -> StratifiedPoolStats {
        self.stats.lock().clone()
    }

    /// Pre-allocate buffers for a specific bucket.
    pub fn preallocate(&self, bucket: SizeBucket, count: usize) {
        if let Some(pool) = self.buckets.get(&bucket) {
            pool.preallocate(count);
        }
    }

    /// Pre-allocate buffers for all buckets.
    pub fn preallocate_all(&self, count_per_bucket: usize) {
        for bucket in SizeBucket::ALL {
            self.preallocate(bucket, count_per_bucket);
        }
    }

    /// Shrink all pools to target utilization.
    ///
    /// Removes excess pooled buffers to free memory.
    pub fn shrink_to(&self, target_per_bucket: usize) {
        for pool in self.buckets.values() {
            let mut free_list = pool.free_list.lock();
            while free_list.len() > target_per_bucket {
                free_list.pop();
                pool.pool_size.fetch_sub(1, Ordering::Relaxed);
            }
        }
    }
}

/// A buffer from a stratified memory pool.
///
/// Tracks which bucket it came from for proper return.
pub struct StratifiedBuffer<'a> {
    inner: PooledBuffer<'a>,
    bucket: SizeBucket,
    #[allow(dead_code)]
    pool: &'a StratifiedMemoryPool,
}

impl<'a> StratifiedBuffer<'a> {
    /// Get the size bucket this buffer was allocated from.
    pub fn bucket(&self) -> SizeBucket {
        self.bucket
    }

    /// Get the actual capacity (bucket size, may be larger than requested).
    pub fn capacity(&self) -> usize {
        self.bucket.size()
    }

    /// Get slice reference.
    pub fn as_slice(&self) -> &[u8] {
        self.inner.as_slice()
    }

    /// Get mutable slice reference.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.inner.as_mut_slice()
    }

    /// Get buffer length.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}

impl<'a> std::ops::Deref for StratifiedBuffer<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<'a> std::ops::DerefMut for StratifiedBuffer<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}

/// Shared stratified memory pool.
pub type SharedStratifiedPool = Arc<StratifiedMemoryPool>;

/// Create a shared stratified memory pool.
pub fn create_stratified_pool(name: impl Into<String>) -> SharedStratifiedPool {
    Arc::new(StratifiedMemoryPool::new(name))
}

/// Create a shared stratified memory pool with custom capacity.
pub fn create_stratified_pool_with_capacity(
    name: impl Into<String>,
    max_buffers_per_bucket: usize,
) -> SharedStratifiedPool {
    Arc::new(StratifiedMemoryPool::with_capacity(
        name,
        max_buffers_per_bucket,
    ))
}

// ============================================================================
// Memory Pressure Reactions
// ============================================================================

use crate::observability::MemoryPressureLevel;

/// Callback type for memory pressure changes.
pub type PressureCallback = Box<dyn Fn(MemoryPressureLevel) + Send + Sync>;

/// Reaction to memory pressure events.
///
/// Pools can be configured to react to memory pressure by shrinking
/// or invoking custom callbacks.
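///
/// # Example
///
/// A brief sketch of constructing each variant:
///
/// ```ignore
/// use ringkernel_core::memory::PressureReaction;
///
/// // Retain at most half of each bucket's capacity under pressure.
/// let shrink = PressureReaction::Shrink { target_utilization: 0.5 };
///
/// // Or react by logging every level change.
/// let log = PressureReaction::Callback(Box::new(|level| {
///     eprintln!("memory pressure is now {:?}", level);
/// }));
/// ```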
pub enum PressureReaction {
    /// No automatic reaction to pressure.
    None,
    /// Automatically shrink pool to target utilization.
    ///
    /// The `target_utilization` is a fraction (0.0 to 1.0) of the max
    /// pool size to retain when under pressure.
    Shrink {
        /// Target utilization as fraction of max capacity.
        target_utilization: f64,
    },
    /// Invoke a custom callback on pressure change.
    Callback(PressureCallback),
}

impl std::fmt::Debug for PressureReaction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => write!(f, "PressureReaction::None"),
            Self::Shrink { target_utilization } => {
                write!(
                    f,
                    "PressureReaction::Shrink {{ target_utilization: {} }}",
                    target_utilization
                )
            }
            Self::Callback(_) => write!(f, "PressureReaction::Callback(<fn>)"),
        }
    }
}

/// Memory pressure handler for stratified pools.
///
/// Monitors memory pressure levels and triggers configured reactions.
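///
/// # Example
///
/// A sketch of wiring a handler to a stratified pool; the pressure source
/// that produces level changes is assumed to live elsewhere.
///
/// ```ignore
/// use ringkernel_core::memory::{PressureHandler, StratifiedMemoryPool};
/// use ringkernel_core::observability::MemoryPressureLevel;
///
/// let pool = StratifiedMemoryPool::new("msgs");
/// let handler = PressureHandler::shrink_to(0.5);
///
/// // On a pressure increase, shrink every bucket to the suggested count.
/// if let Some(keep) = handler.on_pressure_change(
///     MemoryPressureLevel::Critical,
///     pool.max_buffers_per_bucket(),
/// ) {
///     pool.shrink_to(keep);
/// }
/// ```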
pub struct PressureHandler {
    /// Configured reaction.
    reaction: PressureReaction,
    /// Current pressure level.
    current_level: Mutex<MemoryPressureLevel>,
}

impl PressureHandler {
    /// Create a new pressure handler with the specified reaction.
    pub fn new(reaction: PressureReaction) -> Self {
        Self {
            reaction,
            current_level: Mutex::new(MemoryPressureLevel::Normal),
        }
    }

    /// Create a handler with no reaction.
    pub fn no_reaction() -> Self {
        Self::new(PressureReaction::None)
    }

    /// Create a handler that shrinks to target utilization.
    pub fn shrink_to(target_utilization: f64) -> Self {
        Self::new(PressureReaction::Shrink {
            target_utilization: target_utilization.clamp(0.0, 1.0),
        })
    }

    /// Create a handler with a custom callback.
    pub fn with_callback<F>(callback: F) -> Self
    where
        F: Fn(MemoryPressureLevel) + Send + Sync + 'static,
    {
        Self::new(PressureReaction::Callback(Box::new(callback)))
    }

    /// Get the current pressure level.
    pub fn current_level(&self) -> MemoryPressureLevel {
        *self.current_level.lock()
    }

    /// Handle a pressure level change.
    ///
    /// Returns the number of buffers to retain per bucket (if shrinking).
    pub fn on_pressure_change(
        &self,
        new_level: MemoryPressureLevel,
        max_per_bucket: usize,
    ) -> Option<usize> {
        let old_level = {
            let mut current = self.current_level.lock();
            let old = *current;
            *current = new_level;
            old
        };

        // Only react if pressure increased
        if !Self::is_higher_pressure(new_level, old_level) {
            return None;
        }

        match &self.reaction {
            PressureReaction::None => None,
            PressureReaction::Shrink { target_utilization } => {
                // Calculate target based on pressure level
                let pressure_factor = Self::pressure_severity(new_level);
                let adjusted_target = target_utilization * (1.0 - pressure_factor);
                let target_count = ((max_per_bucket as f64) * adjusted_target) as usize;
                Some(target_count.max(1)) // Keep at least 1
            }
            PressureReaction::Callback(callback) => {
                callback(new_level);
                None
            }
        }
    }

    /// Check if new pressure level is higher than old.
    fn is_higher_pressure(new: MemoryPressureLevel, old: MemoryPressureLevel) -> bool {
        Self::pressure_ordinal(new) > Self::pressure_ordinal(old)
    }

    /// Get ordinal value for pressure level comparison.
    fn pressure_ordinal(level: MemoryPressureLevel) -> u8 {
        match level {
            MemoryPressureLevel::Normal => 0,
            MemoryPressureLevel::Elevated => 1,
            MemoryPressureLevel::Warning => 2,
            MemoryPressureLevel::Critical => 3,
            MemoryPressureLevel::OutOfMemory => 4,
        }
    }

    /// Get severity factor (0.0 to 1.0) for pressure level.
    fn pressure_severity(level: MemoryPressureLevel) -> f64 {
        match level {
            MemoryPressureLevel::Normal => 0.0,
            MemoryPressureLevel::Elevated => 0.2,
            MemoryPressureLevel::Warning => 0.5,
            MemoryPressureLevel::Critical => 0.8,
            MemoryPressureLevel::OutOfMemory => 1.0,
        }
    }
}

impl std::fmt::Debug for PressureHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PressureHandler")
            .field("reaction", &self.reaction)
            .field("current_level", &self.current_level())
            .finish()
    }
}

/// Extension trait for pressure-aware memory pools.
pub trait PressureAwarePool {
    /// Handle a memory pressure change event.
    ///
    /// Returns true if the pool took action (e.g., shrunk).
    fn handle_pressure(&self, level: MemoryPressureLevel) -> bool;

    /// Get current pressure level.
    fn pressure_level(&self) -> MemoryPressureLevel;
}

/// Alignment utilities.
///
/// All helpers assume `alignment` is a nonzero power of two.
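///
/// A quick sketch of the helpers:
///
/// ```ignore
/// use ringkernel_core::memory::align;
///
/// // Round a 100-byte payload up to the next cache line.
/// assert_eq!(align::align_up(100, align::CACHE_LINE_SIZE), 128);
/// assert_eq!(align::padding_for(100, 64), 28);
/// assert!(align::is_aligned(128, 64));
/// ```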
pub mod align {
    /// Cache line size (64 bytes on most modern CPUs).
    pub const CACHE_LINE_SIZE: usize = 64;

    /// GPU cache line size (128 bytes on many GPUs).
    pub const GPU_CACHE_LINE_SIZE: usize = 128;

    /// Align a value up to the next multiple of alignment.
    #[inline]
    pub const fn align_up(value: usize, alignment: usize) -> usize {
        let mask = alignment - 1;
        (value + mask) & !mask
    }

    /// Align a value down to the previous multiple of alignment.
    #[inline]
    pub const fn align_down(value: usize, alignment: usize) -> usize {
        let mask = alignment - 1;
        value & !mask
    }

    /// Check if a value is aligned.
    #[inline]
    pub const fn is_aligned(value: usize, alignment: usize) -> bool {
        value & (alignment - 1) == 0
    }

    /// Get required padding for alignment.
    #[inline]
    pub const fn padding_for(offset: usize, alignment: usize) -> usize {
        let misalignment = offset & (alignment - 1);
        if misalignment == 0 {
            0
        } else {
            alignment - misalignment
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pinned_memory() {
        let mut mem = PinnedMemory::<f32>::new(1024).unwrap();
        assert_eq!(mem.len(), 1024);
        assert_eq!(mem.size_bytes(), 1024 * 4);

        // Write some data
        let slice = mem.as_mut_slice();
        for (i, v) in slice.iter_mut().enumerate() {
            *v = i as f32;
        }

        // Read back
        assert_eq!(mem.as_slice()[42], 42.0);
    }

    #[test]
    fn test_pinned_memory_from_slice() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mem = PinnedMemory::from_slice(&data).unwrap();
        assert_eq!(mem.as_slice(), &data[..]);
    }

    #[test]
    fn test_memory_pool() {
        let pool = MemoryPool::new("test", 1024, 10);

        // First allocation should be fresh
        let buf1 = pool.allocate();
        assert_eq!(buf1.len(), 1024);
        drop(buf1);

        // Second allocation should be cached
        let _buf2 = pool.allocate();
        assert_eq!(pool.hit_rate(), 0.5); // 1 hit out of 2 allocations
    }

    #[test]
    fn test_pool_preallocate() {
        let pool = MemoryPool::new("test", 1024, 10);
        pool.preallocate(5);
        assert_eq!(pool.current_size(), 5);

        // All allocations should hit cache
        for _ in 0..5 {
            let _ = pool.allocate();
        }
        assert_eq!(pool.hit_rate(), 1.0);
    }

    #[test]
    fn test_align_up() {
        use align::*;

        assert_eq!(align_up(0, 64), 0);
        assert_eq!(align_up(1, 64), 64);
        assert_eq!(align_up(64, 64), 64);
        assert_eq!(align_up(65, 64), 128);
    }

    #[test]
    fn test_is_aligned() {
        use align::*;

        assert!(is_aligned(0, 64));
        assert!(is_aligned(64, 64));
        assert!(is_aligned(128, 64));
        assert!(!is_aligned(1, 64));
        assert!(!is_aligned(63, 64));
    }

    #[test]
    fn test_padding_for() {
        use align::*;

        assert_eq!(padding_for(0, 64), 0);
        assert_eq!(padding_for(1, 64), 63);
        assert_eq!(padding_for(63, 64), 1);
        assert_eq!(padding_for(64, 64), 0);
    }

    // ========================================================================
    // Size-Stratified Pool Tests
    // ========================================================================

    #[test]
    fn test_size_bucket_sizes() {
        assert_eq!(SizeBucket::Tiny.size(), 256);
        assert_eq!(SizeBucket::Small.size(), 1024);
        assert_eq!(SizeBucket::Medium.size(), 4096);
        assert_eq!(SizeBucket::Large.size(), 16384);
        assert_eq!(SizeBucket::Huge.size(), 65536);
    }

    #[test]
    fn test_size_bucket_selection() {
        // Exact boundaries
        assert_eq!(SizeBucket::for_size(0), SizeBucket::Tiny);
        assert_eq!(SizeBucket::for_size(256), SizeBucket::Tiny);
        assert_eq!(SizeBucket::for_size(257), SizeBucket::Small);
        assert_eq!(SizeBucket::for_size(1024), SizeBucket::Small);
        assert_eq!(SizeBucket::for_size(1025), SizeBucket::Medium);
        assert_eq!(SizeBucket::for_size(4096), SizeBucket::Medium);
        assert_eq!(SizeBucket::for_size(4097), SizeBucket::Large);
        assert_eq!(SizeBucket::for_size(16384), SizeBucket::Large);
        assert_eq!(SizeBucket::for_size(16385), SizeBucket::Huge);
        assert_eq!(SizeBucket::for_size(100000), SizeBucket::Huge);
    }

    #[test]
    fn test_size_bucket_upgrade_downgrade() {
        assert_eq!(SizeBucket::Tiny.upgrade(), SizeBucket::Small);
        assert_eq!(SizeBucket::Small.upgrade(), SizeBucket::Medium);
        assert_eq!(SizeBucket::Medium.upgrade(), SizeBucket::Large);
        assert_eq!(SizeBucket::Large.upgrade(), SizeBucket::Huge);
        assert_eq!(SizeBucket::Huge.upgrade(), SizeBucket::Huge); // Max

        assert_eq!(SizeBucket::Tiny.downgrade(), SizeBucket::Tiny); // Min
        assert_eq!(SizeBucket::Small.downgrade(), SizeBucket::Tiny);
        assert_eq!(SizeBucket::Medium.downgrade(), SizeBucket::Small);
        assert_eq!(SizeBucket::Large.downgrade(), SizeBucket::Medium);
        assert_eq!(SizeBucket::Huge.downgrade(), SizeBucket::Large);
    }

    #[test]
    fn test_stratified_pool_allocation() {
        let pool = StratifiedMemoryPool::new("test");

        // Allocate different sizes
        let buf1 = pool.allocate(100); // Tiny
        let buf2 = pool.allocate(500); // Small
        let buf3 = pool.allocate(2000); // Medium

        assert_eq!(buf1.bucket(), SizeBucket::Tiny);
        assert_eq!(buf2.bucket(), SizeBucket::Small);
        assert_eq!(buf3.bucket(), SizeBucket::Medium);

        // Buffers have full bucket capacity
        assert_eq!(buf1.capacity(), 256);
        assert_eq!(buf2.capacity(), 1024);
        assert_eq!(buf3.capacity(), 4096);
    }

    #[test]
    fn test_stratified_pool_reuse() {
        let pool = StratifiedMemoryPool::new("test");

        // First allocation - fresh
        {
            let _buf = pool.allocate(100);
        }
        // Buffer returned to pool

        // Second allocation - should reuse
        {
            let _buf = pool.allocate(100);
        }

        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.total_hits, 1);
        assert!((stats.hit_rate() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_stratified_pool_stats_per_bucket() {
        let pool = StratifiedMemoryPool::new("test");

        // Allocate from different buckets
        let _buf1 = pool.allocate(100); // Tiny
        let _buf2 = pool.allocate(500); // Small
        let _buf3 = pool.allocate(100); // Tiny again

        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 3);
        assert_eq!(
            stats.allocations_per_bucket.get(&SizeBucket::Tiny),
            Some(&2)
        );
        assert_eq!(
            stats.allocations_per_bucket.get(&SizeBucket::Small),
            Some(&1)
        );
    }

    #[test]
    fn test_stratified_pool_preallocate() {
        let pool = StratifiedMemoryPool::new("test");

        pool.preallocate(SizeBucket::Medium, 5);
        assert_eq!(pool.bucket_size(SizeBucket::Medium), 5);
        assert_eq!(pool.bucket_size(SizeBucket::Tiny), 0);

        // All medium allocations should hit cache
        for _ in 0..5 {
            let _buf = pool.allocate(2000);
        }

        let stats = pool.stats();
        assert_eq!(stats.hits_per_bucket.get(&SizeBucket::Medium), Some(&5));
    }

    #[test]
    fn test_stratified_pool_shrink() {
        let pool = StratifiedMemoryPool::new("test");

        // Preallocate then shrink
        pool.preallocate_all(10);
        assert_eq!(pool.total_pooled(), 50); // 5 buckets * 10

        pool.shrink_to(2);
        assert_eq!(pool.total_pooled(), 10); // 5 buckets * 2
    }

    #[test]
    fn test_stratified_buffer_deref() {
        let pool = StratifiedMemoryPool::new("test");

        let mut buf = pool.allocate(100);

        // Write via DerefMut
        buf[0] = 42;
        buf[1] = 43;

        // Read via Deref
        assert_eq!(buf[0], 42);
        assert_eq!(buf[1], 43);
    }

    // ========================================================================
    // Memory Pressure Reaction Tests
    // ========================================================================

    #[test]
    fn test_pressure_handler_no_reaction() {
        let handler = PressureHandler::no_reaction();
        assert_eq!(handler.current_level(), MemoryPressureLevel::Normal);

        let result = handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert!(result.is_none());
    }

    #[test]
    fn test_pressure_handler_shrink() {
        let handler = PressureHandler::shrink_to(0.5);

        // Normal -> Critical should trigger shrink
        let result = handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert!(result.is_some());
        // With 0.5 target and 0.8 severity, adjusted = 0.5 * (1.0 - 0.8) = 0.1
        // 10 * 0.1 = 1 -> max(1, 1) = 1
        assert!(result.unwrap() >= 1);
    }

    #[test]
    fn test_pressure_handler_callback() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let handler = PressureHandler::with_callback(move |level| {
            if level == MemoryPressureLevel::Critical {
                called_clone.store(true, Ordering::SeqCst);
            }
        });

        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn test_pressure_handler_only_reacts_to_increase() {
        let handler = PressureHandler::shrink_to(0.5);

        // Start at Critical
        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);

        // Going back to Normal should not trigger
        let result = handler.on_pressure_change(MemoryPressureLevel::Normal, 10);
        assert!(result.is_none());
    }

    #[test]
    fn test_pressure_handler_level_tracking() {
        let handler = PressureHandler::no_reaction();

        assert_eq!(handler.current_level(), MemoryPressureLevel::Normal);

        handler.on_pressure_change(MemoryPressureLevel::Warning, 10);
        assert_eq!(handler.current_level(), MemoryPressureLevel::Warning);

        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert_eq!(handler.current_level(), MemoryPressureLevel::Critical);
    }

    #[test]
    fn test_pressure_reaction_debug() {
        let none = PressureReaction::None;
        assert!(format!("{:?}", none).contains("None"));

        let shrink = PressureReaction::Shrink {
            target_utilization: 0.5,
        };
        assert!(format!("{:?}", shrink).contains("0.5"));

        let callback = PressureReaction::Callback(Box::new(|_| {}));
        assert!(format!("{:?}", callback).contains("Callback"));
    }

    #[test]
    fn test_pressure_handler_debug() {
        let handler = PressureHandler::shrink_to(0.3);
        let debug_str = format!("{:?}", handler);
        assert!(debug_str.contains("PressureHandler"));
        assert!(debug_str.contains("Shrink"));
    }

    #[test]
    fn test_pressure_severity_values() {
        // Test that severity increases with pressure level
        let normal = PressureHandler::pressure_severity(MemoryPressureLevel::Normal);
        let elevated = PressureHandler::pressure_severity(MemoryPressureLevel::Elevated);
        let warning = PressureHandler::pressure_severity(MemoryPressureLevel::Warning);
        let critical = PressureHandler::pressure_severity(MemoryPressureLevel::Critical);
        let oom = PressureHandler::pressure_severity(MemoryPressureLevel::OutOfMemory);

        assert!(normal < elevated);
        assert!(elevated < warning);
        assert!(warning < critical);
        assert!(critical < oom);
        assert!(oom <= 1.0);
    }
}