use std::alloc::{alloc, alloc_zeroed, dealloc, Layout};
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use parking_lot::Mutex;

use crate::error::{Result, RingKernelError};
15
/// Abstraction over a device-resident buffer that host code can move bytes
/// into and out of.
pub trait GpuBuffer: Send + Sync {
    /// Size of the buffer in bytes.
    fn size(&self) -> usize;

    /// Raw device address of the buffer, exposed as an integer handle.
    fn device_ptr(&self) -> usize;

    /// Copies `data` from host memory into the device buffer.
    fn copy_from_host(&self, data: &[u8]) -> Result<()>;

    /// Copies the device buffer's contents into the host slice `data`.
    fn copy_to_host(&self, data: &mut [u8]) -> Result<()>;
}
30
/// Device-side allocator that hands out [`GpuBuffer`]s.
pub trait DeviceMemory: Send + Sync {
    /// Allocates `size` bytes of device memory.
    fn allocate(&self, size: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Allocates `size` bytes with the requested `alignment`.
    fn allocate_aligned(&self, size: usize, alignment: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Total device memory in bytes.
    fn total_memory(&self) -> usize;

    /// Currently available device memory in bytes.
    fn free_memory(&self) -> usize;
}
45
/// Host-side buffer of `count` elements of `T`, allocated once and freed on drop.
///
/// NOTE(review): despite the name, this is backed by the global Rust allocator;
/// no page-locking ("pinning") call is visible here — confirm whether a backend
/// supplies that elsewhere.
pub struct PinnedMemory<T: Copy> {
    /// Base pointer to the allocation; never null once constructed.
    ptr: NonNull<T>,
    /// Number of `T` elements in the buffer.
    len: usize,
    /// Layout used for allocation; required again for deallocation.
    layout: Layout,
    /// Marks logical ownership of `T` values for variance/drop-check purposes.
    _marker: PhantomData<T>,
}
56
57impl<T: Copy> PinnedMemory<T> {
58 pub fn new(count: usize) -> Result<Self> {
65 if count == 0 {
66 return Err(RingKernelError::InvalidConfig(
67 "Cannot allocate zero-sized buffer".to_string(),
68 ));
69 }
70
71 let layout =
72 Layout::array::<T>(count).map_err(|_| RingKernelError::HostAllocationFailed {
73 size: count * std::mem::size_of::<T>(),
74 })?;
75
76 let ptr = unsafe { alloc(layout) };
79
80 if ptr.is_null() {
81 return Err(RingKernelError::HostAllocationFailed {
82 size: layout.size(),
83 });
84 }
85
86 Ok(Self {
87 ptr: NonNull::new(ptr as *mut T).expect("ptr verified non-null above"),
89 len: count,
90 layout,
91 _marker: PhantomData,
92 })
93 }
94
95 pub fn from_slice(data: &[T]) -> Result<Self> {
97 let mut mem = Self::new(data.len())?;
98 mem.as_mut_slice().copy_from_slice(data);
99 Ok(mem)
100 }
101
102 pub fn as_slice(&self) -> &[T] {
104 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
105 }
106
107 pub fn as_mut_slice(&mut self) -> &mut [T] {
109 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
110 }
111
112 pub fn as_ptr(&self) -> *const T {
114 self.ptr.as_ptr()
115 }
116
117 pub fn as_mut_ptr(&mut self) -> *mut T {
119 self.ptr.as_ptr()
120 }
121
122 pub fn len(&self) -> usize {
124 self.len
125 }
126
127 pub fn is_empty(&self) -> bool {
129 self.len == 0
130 }
131
132 pub fn size_bytes(&self) -> usize {
134 self.len * std::mem::size_of::<T>()
135 }
136}
137
impl<T: Copy> Drop for PinnedMemory<T> {
    fn drop(&mut self) {
        // SAFETY: `ptr` was produced by the global allocator with exactly this
        // `layout` in `new`, and is deallocated only once (here).
        unsafe {
            dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
        }
    }
}
146
// SAFETY: PinnedMemory exclusively owns its allocation, so moving it to
// another thread is sound whenever `T` itself is `Send`.
unsafe impl<T: Copy + Send> Send for PinnedMemory<T> {}
// SAFETY: shared access only hands out `&T`/`&[T]`, sound whenever `T: Sync`.
unsafe impl<T: Copy + Sync> Sync for PinnedMemory<T> {}
150
/// Pool of fixed-size `Vec<u8>` buffers that recycles freed allocations.
pub struct MemoryPool {
    /// Human-readable pool name (used for diagnostics).
    name: String,
    /// Size in bytes of every buffer this pool hands out.
    buffer_size: usize,
    /// Maximum number of idle buffers kept in `free_list`.
    max_buffers: usize,
    /// Idle buffers awaiting reuse.
    free_list: Mutex<Vec<Vec<u8>>>,
    /// Total `allocate` calls since creation.
    total_allocations: AtomicUsize,
    /// Allocations served from `free_list` rather than freshly allocated.
    cache_hits: AtomicUsize,
    /// Current number of idle buffers (mirrors `free_list.len()`).
    pool_size: AtomicUsize,
}
171
172impl MemoryPool {
173 pub fn new(name: impl Into<String>, buffer_size: usize, max_buffers: usize) -> Self {
175 Self {
176 name: name.into(),
177 buffer_size,
178 max_buffers,
179 free_list: Mutex::new(Vec::with_capacity(max_buffers)),
180 total_allocations: AtomicUsize::new(0),
181 cache_hits: AtomicUsize::new(0),
182 pool_size: AtomicUsize::new(0),
183 }
184 }
185
186 pub fn allocate(&self) -> PooledBuffer<'_> {
188 self.total_allocations.fetch_add(1, Ordering::Relaxed);
189
190 let buffer = {
191 let mut free = self.free_list.lock();
192 if let Some(buf) = free.pop() {
193 self.cache_hits.fetch_add(1, Ordering::Relaxed);
194 self.pool_size.fetch_sub(1, Ordering::Relaxed);
195 buf
196 } else {
197 vec![0u8; self.buffer_size]
198 }
199 };
200
201 PooledBuffer {
202 buffer: Some(buffer),
203 pool: self,
204 }
205 }
206
207 fn return_buffer(&self, mut buffer: Vec<u8>) {
209 let mut free = self.free_list.lock();
210 if free.len() < self.max_buffers {
211 buffer.clear();
212 buffer.resize(self.buffer_size, 0);
213 free.push(buffer);
214 self.pool_size.fetch_add(1, Ordering::Relaxed);
215 }
216 }
218
219 pub fn name(&self) -> &str {
221 &self.name
222 }
223
224 pub fn buffer_size(&self) -> usize {
226 self.buffer_size
227 }
228
229 pub fn current_size(&self) -> usize {
231 self.pool_size.load(Ordering::Relaxed)
232 }
233
234 pub fn hit_rate(&self) -> f64 {
236 let total = self.total_allocations.load(Ordering::Relaxed);
237 let hits = self.cache_hits.load(Ordering::Relaxed);
238 if total == 0 {
239 0.0
240 } else {
241 hits as f64 / total as f64
242 }
243 }
244
245 pub fn preallocate(&self, count: usize) {
247 let count = count.min(self.max_buffers);
248 let mut free = self.free_list.lock();
249 for _ in free.len()..count {
250 free.push(vec![0u8; self.buffer_size]);
251 self.pool_size.fetch_add(1, Ordering::Relaxed);
252 }
253 }
254}
255
/// RAII handle to a pool buffer; hands the buffer back to its pool on drop.
pub struct PooledBuffer<'a> {
    /// `Some` until `Drop` takes the buffer back.
    buffer: Option<Vec<u8>>,
    /// Owning pool the buffer is returned to.
    pool: &'a MemoryPool,
}
263
264impl<'a> PooledBuffer<'a> {
265 pub fn as_slice(&self) -> &[u8] {
267 self.buffer.as_deref().unwrap_or(&[])
268 }
269
270 pub fn as_mut_slice(&mut self) -> &mut [u8] {
272 self.buffer.as_deref_mut().unwrap_or(&mut [])
273 }
274
275 pub fn len(&self) -> usize {
277 self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
278 }
279
280 pub fn is_empty(&self) -> bool {
282 self.len() == 0
283 }
284}
285
impl<'a> Drop for PooledBuffer<'a> {
    /// Returns the buffer to the owning pool (which may discard it when full).
    fn drop(&mut self) {
        if let Some(buffer) = self.buffer.take() {
            self.pool.return_buffer(buffer);
        }
    }
}
293
/// Lets a pooled buffer be used anywhere a `&[u8]` is expected.
impl<'a> std::ops::Deref for PooledBuffer<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}
301
/// Lets a pooled buffer be used anywhere a `&mut [u8]` is expected.
impl<'a> std::ops::DerefMut for PooledBuffer<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}
307
/// Thread-safe shared handle to a [`MemoryPool`].
pub type SharedMemoryPool = Arc<MemoryPool>;
310
311pub fn create_pool(
313 name: impl Into<String>,
314 buffer_size: usize,
315 max_buffers: usize,
316) -> SharedMemoryPool {
317 Arc::new(MemoryPool::new(name, buffer_size, max_buffers))
318}
319
/// Discrete buffer size classes used by [`StratifiedMemoryPool`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SizeBucket {
    /// 256 B buffers.
    Tiny,
    /// 1 KiB buffers.
    Small,
    /// 4 KiB buffers (the default class).
    #[default]
    Medium,
    /// 16 KiB buffers.
    Large,
    /// 64 KiB buffers.
    Huge,
}
342
343impl SizeBucket {
344 pub const ALL: [SizeBucket; 5] = [
346 SizeBucket::Tiny,
347 SizeBucket::Small,
348 SizeBucket::Medium,
349 SizeBucket::Large,
350 SizeBucket::Huge,
351 ];
352
353 pub fn size(&self) -> usize {
355 match self {
356 Self::Tiny => 256,
357 Self::Small => 1024,
358 Self::Medium => 4096,
359 Self::Large => 16384,
360 Self::Huge => 65536,
361 }
362 }
363
364 pub fn for_size(requested: usize) -> Self {
368 if requested <= 256 {
369 Self::Tiny
370 } else if requested <= 1024 {
371 Self::Small
372 } else if requested <= 4096 {
373 Self::Medium
374 } else if requested <= 16384 {
375 Self::Large
376 } else {
377 Self::Huge
378 }
379 }
380
381 pub fn upgrade(&self) -> Self {
383 match self {
384 Self::Tiny => Self::Small,
385 Self::Small => Self::Medium,
386 Self::Medium => Self::Large,
387 Self::Large => Self::Huge,
388 Self::Huge => Self::Huge,
389 }
390 }
391
392 pub fn downgrade(&self) -> Self {
394 match self {
395 Self::Tiny => Self::Tiny,
396 Self::Small => Self::Tiny,
397 Self::Medium => Self::Small,
398 Self::Large => Self::Medium,
399 Self::Huge => Self::Large,
400 }
401 }
402}
403
404impl std::fmt::Display for SizeBucket {
405 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406 match self {
407 Self::Tiny => write!(f, "Tiny(256B)"),
408 Self::Small => write!(f, "Small(1KB)"),
409 Self::Medium => write!(f, "Medium(4KB)"),
410 Self::Large => write!(f, "Large(16KB)"),
411 Self::Huge => write!(f, "Huge(64KB)"),
412 }
413 }
414}
415
/// Aggregate allocation statistics for a [`StratifiedMemoryPool`].
#[derive(Debug, Clone, Default)]
pub struct StratifiedPoolStats {
    /// Total allocations across all buckets.
    pub total_allocations: usize,
    /// Allocations served from a bucket's free list.
    pub total_hits: usize,
    /// Allocation counts keyed by size class.
    pub allocations_per_bucket: std::collections::HashMap<SizeBucket, usize>,
    /// Cache-hit counts keyed by size class.
    pub hits_per_bucket: std::collections::HashMap<SizeBucket, usize>,
}
428
429impl StratifiedPoolStats {
430 pub fn hit_rate(&self) -> f64 {
432 if self.total_allocations == 0 {
433 0.0
434 } else {
435 self.total_hits as f64 / self.total_allocations as f64
436 }
437 }
438
439 pub fn bucket_hit_rate(&self, bucket: SizeBucket) -> f64 {
441 let allocs = self
442 .allocations_per_bucket
443 .get(&bucket)
444 .copied()
445 .unwrap_or(0);
446 let hits = self.hits_per_bucket.get(&bucket).copied().unwrap_or(0);
447 if allocs == 0 {
448 0.0
449 } else {
450 hits as f64 / allocs as f64
451 }
452 }
453}
454
/// Memory pool stratified into fixed size classes (see [`SizeBucket`]), with
/// one [`MemoryPool`] per class and aggregated hit statistics.
pub struct StratifiedMemoryPool {
    /// Pool name; sub-pools are named `"{name}_{bucket}"`.
    name: String,
    /// One fixed-size sub-pool per size class.
    buckets: std::collections::HashMap<SizeBucket, MemoryPool>,
    /// Idle-buffer cap applied to each sub-pool.
    max_buffers_per_bucket: usize,
    /// Aggregated allocation/hit statistics.
    stats: Mutex<StratifiedPoolStats>,
}
482
impl StratifiedMemoryPool {
    /// Creates a stratified pool with a default cap of 16 idle buffers per bucket.
    pub fn new(name: impl Into<String>) -> Self {
        Self::with_capacity(name, 16)
    }

    /// Creates a stratified pool keeping at most `max_buffers_per_bucket` idle
    /// buffers per size class. A sub-pool for every [`SizeBucket`] variant is
    /// built eagerly and named `"{name}_{bucket}"`.
    pub fn with_capacity(name: impl Into<String>, max_buffers_per_bucket: usize) -> Self {
        let name = name.into();
        let mut buckets = std::collections::HashMap::new();

        for bucket in SizeBucket::ALL {
            let pool_name = format!("{}_{}", name, bucket);
            buckets.insert(
                bucket,
                MemoryPool::new(pool_name, bucket.size(), max_buffers_per_bucket),
            );
        }

        Self {
            name,
            buckets,
            max_buffers_per_bucket,
            stats: Mutex::new(StratifiedPoolStats::default()),
        }
    }

    /// Allocates from the smallest bucket that can hold `size` bytes.
    pub fn allocate(&self, size: usize) -> StratifiedBuffer<'_> {
        let bucket = SizeBucket::for_size(size);
        self.allocate_bucket(bucket)
    }

    /// Allocates directly from the given size class and records statistics.
    pub fn allocate_bucket(&self, bucket: SizeBucket) -> StratifiedBuffer<'_> {
        let pool = self
            .buckets
            .get(&bucket)
            .expect("all SizeBucket variants are inserted in new()");

        // NOTE(review): this peek and the pop inside `pool.allocate()` are
        // separate lock acquisitions, so under concurrent allocation the hit
        // statistics can over- or under-count. Harmless for metrics, but not
        // exact under contention.
        let was_cached = pool.current_size() > 0;
        let buffer = pool.allocate();

        // Update aggregate stats under the stats lock only.
        {
            let mut stats = self.stats.lock();
            stats.total_allocations += 1;
            *stats.allocations_per_bucket.entry(bucket).or_insert(0) += 1;
            if was_cached {
                stats.total_hits += 1;
                *stats.hits_per_bucket.entry(bucket).or_insert(0) += 1;
            }
        }

        StratifiedBuffer {
            inner: buffer,
            bucket,
            pool: self,
        }
    }

    /// Pool name supplied at construction.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Idle-buffer cap applied to every bucket.
    pub fn max_buffers_per_bucket(&self) -> usize {
        self.max_buffers_per_bucket
    }

    /// Number of idle buffers currently cached for `bucket`.
    pub fn bucket_size(&self, bucket: SizeBucket) -> usize {
        self.buckets
            .get(&bucket)
            .map(|p| p.current_size())
            .unwrap_or(0)
    }

    /// Total idle buffers cached across all buckets.
    pub fn total_pooled(&self) -> usize {
        self.buckets.values().map(|p| p.current_size()).sum()
    }

    /// Snapshot (clone) of the aggregated statistics.
    pub fn stats(&self) -> StratifiedPoolStats {
        self.stats.lock().clone()
    }

    /// Pre-fills `bucket`'s free list up to `count` buffers (capped by the
    /// per-bucket maximum inside `MemoryPool::preallocate`).
    pub fn preallocate(&self, bucket: SizeBucket, count: usize) {
        if let Some(pool) = self.buckets.get(&bucket) {
            pool.preallocate(count);
        }
    }

    /// Pre-fills every bucket's free list with up to `count_per_bucket` buffers.
    pub fn preallocate_all(&self, count_per_bucket: usize) {
        for bucket in SizeBucket::ALL {
            self.preallocate(bucket, count_per_bucket);
        }
    }

    /// Drops idle buffers until each bucket holds at most `target_per_bucket`.
    /// Reaches directly into each sub-pool's free list and size counter
    /// (same-module private-field access).
    pub fn shrink_to(&self, target_per_bucket: usize) {
        for pool in self.buckets.values() {
            let mut free_list = pool.free_list.lock();
            while free_list.len() > target_per_bucket {
                free_list.pop();
                pool.pool_size.fetch_sub(1, Ordering::Relaxed);
            }
        }
    }
}
604
/// Buffer borrowed from a [`StratifiedMemoryPool`], tagged with its size class.
pub struct StratifiedBuffer<'a> {
    /// Underlying pooled buffer; returns itself to the sub-pool on drop.
    inner: PooledBuffer<'a>,
    /// Size class this buffer was drawn from.
    bucket: SizeBucket,
    /// Owning stratified pool, retained to tie the borrow's lifetime to it.
    #[allow(dead_code)]
    pool: &'a StratifiedMemoryPool,
}
614
impl<'a> StratifiedBuffer<'a> {
    /// Size class this buffer was drawn from.
    pub fn bucket(&self) -> SizeBucket {
        self.bucket
    }

    /// Capacity in bytes implied by the bucket (not the requested size).
    pub fn capacity(&self) -> usize {
        self.bucket.size()
    }

    /// Immutable view of the buffer contents.
    pub fn as_slice(&self) -> &[u8] {
        self.inner.as_slice()
    }

    /// Mutable view of the buffer contents.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.inner.as_mut_slice()
    }

    /// Length in bytes of the underlying buffer.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// True when the underlying buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
646
/// Lets a stratified buffer be used anywhere a `&[u8]` is expected.
impl<'a> std::ops::Deref for StratifiedBuffer<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}
654
/// Lets a stratified buffer be used anywhere a `&mut [u8]` is expected.
impl<'a> std::ops::DerefMut for StratifiedBuffer<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}
660
/// Thread-safe shared handle to a [`StratifiedMemoryPool`].
pub type SharedStratifiedPool = Arc<StratifiedMemoryPool>;
663
664pub fn create_stratified_pool(name: impl Into<String>) -> SharedStratifiedPool {
666 Arc::new(StratifiedMemoryPool::new(name))
667}
668
669pub fn create_stratified_pool_with_capacity(
671 name: impl Into<String>,
672 max_buffers_per_bucket: usize,
673) -> SharedStratifiedPool {
674 Arc::new(StratifiedMemoryPool::with_capacity(
675 name,
676 max_buffers_per_bucket,
677 ))
678}
679
680use crate::observability::MemoryPressureLevel;
685
/// Callback invoked with the new level when memory pressure rises
/// (see [`PressureReaction::Callback`]).
pub type PressureCallback = Box<dyn Fn(MemoryPressureLevel) + Send + Sync>;
688
/// What a [`PressureHandler`] does when memory pressure increases.
pub enum PressureReaction {
    /// Track the level but take no action.
    None,
    /// Shrink pools toward a fraction of their capacity.
    Shrink {
        /// Desired fraction of capacity to keep (clamped to 0.0..=1.0 by
        /// [`PressureHandler::shrink_to`]); scaled down further as pressure
        /// severity rises.
        target_utilization: f64,
    },
    /// Invoke a user-supplied callback with the new level.
    Callback(PressureCallback),
}
707
708impl std::fmt::Debug for PressureReaction {
709 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
710 match self {
711 Self::None => write!(f, "PressureReaction::None"),
712 Self::Shrink { target_utilization } => {
713 write!(
714 f,
715 "PressureReaction::Shrink {{ target_utilization: {} }}",
716 target_utilization
717 )
718 }
719 Self::Callback(_) => write!(f, "PressureReaction::Callback(<fn>)"),
720 }
721 }
722}
723
/// Tracks the current memory-pressure level and reacts when it increases.
pub struct PressureHandler {
    /// Configured reaction to rising pressure.
    reaction: PressureReaction,
    /// Last level seen via `on_pressure_change` (starts at `Normal`).
    current_level: Mutex<MemoryPressureLevel>,
}
733
734impl PressureHandler {
735 pub fn new(reaction: PressureReaction) -> Self {
737 Self {
738 reaction,
739 current_level: Mutex::new(MemoryPressureLevel::Normal),
740 }
741 }
742
743 pub fn no_reaction() -> Self {
745 Self::new(PressureReaction::None)
746 }
747
748 pub fn shrink_to(target_utilization: f64) -> Self {
750 Self::new(PressureReaction::Shrink {
751 target_utilization: target_utilization.clamp(0.0, 1.0),
752 })
753 }
754
755 pub fn with_callback<F>(callback: F) -> Self
757 where
758 F: Fn(MemoryPressureLevel) + Send + Sync + 'static,
759 {
760 Self::new(PressureReaction::Callback(Box::new(callback)))
761 }
762
763 pub fn current_level(&self) -> MemoryPressureLevel {
765 *self.current_level.lock()
766 }
767
768 pub fn on_pressure_change(
772 &self,
773 new_level: MemoryPressureLevel,
774 max_per_bucket: usize,
775 ) -> Option<usize> {
776 let old_level = {
777 let mut current = self.current_level.lock();
778 let old = *current;
779 *current = new_level;
780 old
781 };
782
783 if !Self::is_higher_pressure(new_level, old_level) {
785 return None;
786 }
787
788 match &self.reaction {
789 PressureReaction::None => None,
790 PressureReaction::Shrink { target_utilization } => {
791 let pressure_factor = Self::pressure_severity(new_level);
793 let adjusted_target = target_utilization * (1.0 - pressure_factor);
794 let target_count = ((max_per_bucket as f64) * adjusted_target) as usize;
795 Some(target_count.max(1)) }
797 PressureReaction::Callback(callback) => {
798 callback(new_level);
799 None
800 }
801 }
802 }
803
804 fn is_higher_pressure(new: MemoryPressureLevel, old: MemoryPressureLevel) -> bool {
806 Self::pressure_ordinal(new) > Self::pressure_ordinal(old)
807 }
808
809 fn pressure_ordinal(level: MemoryPressureLevel) -> u8 {
811 match level {
812 MemoryPressureLevel::Normal => 0,
813 MemoryPressureLevel::Elevated => 1,
814 MemoryPressureLevel::Warning => 2,
815 MemoryPressureLevel::Critical => 3,
816 MemoryPressureLevel::OutOfMemory => 4,
817 }
818 }
819
820 fn pressure_severity(level: MemoryPressureLevel) -> f64 {
822 match level {
823 MemoryPressureLevel::Normal => 0.0,
824 MemoryPressureLevel::Elevated => 0.2,
825 MemoryPressureLevel::Warning => 0.5,
826 MemoryPressureLevel::Critical => 0.8,
827 MemoryPressureLevel::OutOfMemory => 1.0,
828 }
829 }
830}
831
/// Manual `Debug`: `reaction` holds a non-`Debug`-derivable callback variant.
impl std::fmt::Debug for PressureHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PressureHandler")
            .field("reaction", &self.reaction)
            .field("current_level", &self.current_level())
            .finish()
    }
}
840
/// Implemented by pools that can respond to system memory pressure.
pub trait PressureAwarePool {
    /// Reacts to `level`.
    /// NOTE(review): the meaning of the returned bool (presumably "took
    /// action") is not established here — confirm against implementors.
    fn handle_pressure(&self, level: MemoryPressureLevel) -> bool;

    /// Current pressure level as seen by the pool.
    fn pressure_level(&self) -> MemoryPressureLevel;
}
851
/// Alignment helpers. Every function requires `alignment` to be a power of
/// two; this precondition is now checked with `debug_assert!` in debug builds
/// (a non-power-of-two alignment previously produced silently wrong results).
pub mod align {
    /// Typical CPU cache line size in bytes.
    pub const CACHE_LINE_SIZE: usize = 64;

    /// Typical GPU cache line size in bytes.
    pub const GPU_CACHE_LINE_SIZE: usize = 128;

    /// Rounds `value` up to the next multiple of `alignment`.
    ///
    /// May wrap (and panic in debug builds) when `value` is within
    /// `alignment - 1` of `usize::MAX`.
    #[inline]
    pub const fn align_up(value: usize, alignment: usize) -> usize {
        debug_assert!(alignment.is_power_of_two());
        let mask = alignment - 1;
        (value + mask) & !mask
    }

    /// Rounds `value` down to the previous multiple of `alignment`.
    #[inline]
    pub const fn align_down(value: usize, alignment: usize) -> usize {
        debug_assert!(alignment.is_power_of_two());
        let mask = alignment - 1;
        value & !mask
    }

    /// True when `value` is a multiple of `alignment`.
    #[inline]
    pub const fn is_aligned(value: usize, alignment: usize) -> bool {
        debug_assert!(alignment.is_power_of_two());
        value & (alignment - 1) == 0
    }

    /// Bytes of padding needed to advance `offset` to the next aligned
    /// offset; zero when `offset` is already aligned.
    #[inline]
    pub const fn padding_for(offset: usize, alignment: usize) -> usize {
        debug_assert!(alignment.is_power_of_two());
        let misalignment = offset & (alignment - 1);
        if misalignment == 0 {
            0
        } else {
            alignment - misalignment
        }
    }
}
891
#[cfg(test)]
mod tests {
    use super::*;

    // PinnedMemory basics: sizes, element writes, reads.
    #[test]
    fn test_pinned_memory() {
        let mut mem = PinnedMemory::<f32>::new(1024).unwrap();
        assert_eq!(mem.len(), 1024);
        assert_eq!(mem.size_bytes(), 1024 * 4);

        let slice = mem.as_mut_slice();
        for (i, v) in slice.iter_mut().enumerate() {
            *v = i as f32;
        }

        assert_eq!(mem.as_slice()[42], 42.0);
    }

    // Round-trip a host slice through PinnedMemory.
    #[test]
    fn test_pinned_memory_from_slice() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mem = PinnedMemory::from_slice(&data).unwrap();
        assert_eq!(mem.as_slice(), &data[..]);
    }

    // First allocation misses; after the drop returns the buffer, the second
    // allocation is a cache hit -> 1 hit / 2 allocations.
    #[test]
    fn test_memory_pool() {
        let pool = MemoryPool::new("test", 1024, 10);

        let buf1 = pool.allocate();
        assert_eq!(buf1.len(), 1024);
        drop(buf1);

        let _buf2 = pool.allocate();
        assert_eq!(pool.hit_rate(), 0.5);
    }

    // `let _ =` drops each buffer immediately, returning it to the pool, so
    // every allocation is served from the preallocated free list.
    #[test]
    fn test_pool_preallocate() {
        let pool = MemoryPool::new("test", 1024, 10);
        pool.preallocate(5);
        assert_eq!(pool.current_size(), 5);

        for _ in 0..5 {
            let _ = pool.allocate();
        }
        assert_eq!(pool.hit_rate(), 1.0);
    }

    #[test]
    fn test_align_up() {
        use align::*;

        assert_eq!(align_up(0, 64), 0);
        assert_eq!(align_up(1, 64), 64);
        assert_eq!(align_up(64, 64), 64);
        assert_eq!(align_up(65, 64), 128);
    }

    #[test]
    fn test_is_aligned() {
        use align::*;

        assert!(is_aligned(0, 64));
        assert!(is_aligned(64, 64));
        assert!(is_aligned(128, 64));
        assert!(!is_aligned(1, 64));
        assert!(!is_aligned(63, 64));
    }

    #[test]
    fn test_padding_for() {
        use align::*;

        assert_eq!(padding_for(0, 64), 0);
        assert_eq!(padding_for(1, 64), 63);
        assert_eq!(padding_for(63, 64), 1);
        assert_eq!(padding_for(64, 64), 0);
    }

    // Each bucket is 4x the previous: 256 B .. 64 KiB.
    #[test]
    fn test_size_bucket_sizes() {
        assert_eq!(SizeBucket::Tiny.size(), 256);
        assert_eq!(SizeBucket::Small.size(), 1024);
        assert_eq!(SizeBucket::Medium.size(), 4096);
        assert_eq!(SizeBucket::Large.size(), 16384);
        assert_eq!(SizeBucket::Huge.size(), 65536);
    }

    // Bucket boundaries are inclusive; oversize requests map to Huge.
    #[test]
    fn test_size_bucket_selection() {
        assert_eq!(SizeBucket::for_size(0), SizeBucket::Tiny);
        assert_eq!(SizeBucket::for_size(256), SizeBucket::Tiny);
        assert_eq!(SizeBucket::for_size(257), SizeBucket::Small);
        assert_eq!(SizeBucket::for_size(1024), SizeBucket::Small);
        assert_eq!(SizeBucket::for_size(1025), SizeBucket::Medium);
        assert_eq!(SizeBucket::for_size(4096), SizeBucket::Medium);
        assert_eq!(SizeBucket::for_size(4097), SizeBucket::Large);
        assert_eq!(SizeBucket::for_size(16384), SizeBucket::Large);
        assert_eq!(SizeBucket::for_size(16385), SizeBucket::Huge);
        assert_eq!(SizeBucket::for_size(100000), SizeBucket::Huge);
    }

    // upgrade/downgrade saturate at Huge/Tiny respectively.
    #[test]
    fn test_size_bucket_upgrade_downgrade() {
        assert_eq!(SizeBucket::Tiny.upgrade(), SizeBucket::Small);
        assert_eq!(SizeBucket::Small.upgrade(), SizeBucket::Medium);
        assert_eq!(SizeBucket::Medium.upgrade(), SizeBucket::Large);
        assert_eq!(SizeBucket::Large.upgrade(), SizeBucket::Huge);
        assert_eq!(SizeBucket::Huge.upgrade(), SizeBucket::Huge);
        assert_eq!(SizeBucket::Tiny.downgrade(), SizeBucket::Tiny);
        assert_eq!(SizeBucket::Small.downgrade(), SizeBucket::Tiny);
        assert_eq!(SizeBucket::Medium.downgrade(), SizeBucket::Small);
        assert_eq!(SizeBucket::Large.downgrade(), SizeBucket::Medium);
        assert_eq!(SizeBucket::Huge.downgrade(), SizeBucket::Large);
    }

    // Requested size picks the smallest bucket that fits; capacity is the
    // bucket size, not the requested size.
    #[test]
    fn test_stratified_pool_allocation() {
        let pool = StratifiedMemoryPool::new("test");

        let buf1 = pool.allocate(100);
        let buf2 = pool.allocate(500);
        let buf3 = pool.allocate(2000);
        assert_eq!(buf1.bucket(), SizeBucket::Tiny);
        assert_eq!(buf2.bucket(), SizeBucket::Small);
        assert_eq!(buf3.bucket(), SizeBucket::Medium);

        assert_eq!(buf1.capacity(), 256);
        assert_eq!(buf2.capacity(), 1024);
        assert_eq!(buf3.capacity(), 4096);
    }

    // Dropping the first buffer returns it to the Tiny sub-pool, so the
    // second identical request is a hit.
    #[test]
    fn test_stratified_pool_reuse() {
        let pool = StratifiedMemoryPool::new("test");

        {
            let _buf = pool.allocate(100);
        }
        {
            let _buf = pool.allocate(100);
        }

        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.total_hits, 1);
        assert!((stats.hit_rate() - 0.5).abs() < 0.001);
    }

    // Per-bucket allocation counters track each size class independently.
    #[test]
    fn test_stratified_pool_stats_per_bucket() {
        let pool = StratifiedMemoryPool::new("test");

        let _buf1 = pool.allocate(100);
        let _buf2 = pool.allocate(500);
        let _buf3 = pool.allocate(100);
        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 3);
        assert_eq!(
            stats.allocations_per_bucket.get(&SizeBucket::Tiny),
            Some(&2)
        );
        assert_eq!(
            stats.allocations_per_bucket.get(&SizeBucket::Small),
            Some(&1)
        );
    }

    // Preallocation fills only the requested bucket; each scoped allocation
    // drops back into the pool, so all five Medium requests are hits.
    #[test]
    fn test_stratified_pool_preallocate() {
        let pool = StratifiedMemoryPool::new("test");

        pool.preallocate(SizeBucket::Medium, 5);
        assert_eq!(pool.bucket_size(SizeBucket::Medium), 5);
        assert_eq!(pool.bucket_size(SizeBucket::Tiny), 0);

        for _ in 0..5 {
            let _buf = pool.allocate(2000);
        }

        let stats = pool.stats();
        assert_eq!(stats.hits_per_bucket.get(&SizeBucket::Medium), Some(&5));
    }

    // 5 buckets x 10 buffers = 50 pooled; shrinking to 2 per bucket leaves 10.
    #[test]
    fn test_stratified_pool_shrink() {
        let pool = StratifiedMemoryPool::new("test");

        pool.preallocate_all(10);
        assert_eq!(pool.total_pooled(), 50);
        pool.shrink_to(2);
        assert_eq!(pool.total_pooled(), 10);
    }

    // Deref/DerefMut give direct byte-slice indexing on the buffer.
    #[test]
    fn test_stratified_buffer_deref() {
        let pool = StratifiedMemoryPool::new("test");

        let mut buf = pool.allocate(100);

        buf[0] = 42;
        buf[1] = 43;

        assert_eq!(buf[0], 42);
        assert_eq!(buf[1], 43);
    }

    #[test]
    fn test_pressure_handler_no_reaction() {
        let handler = PressureHandler::no_reaction();
        assert_eq!(handler.current_level(), MemoryPressureLevel::Normal);

        let result = handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert!(result.is_none());
    }

    // Rising pressure with a Shrink reaction yields a target count >= 1.
    #[test]
    fn test_pressure_handler_shrink() {
        let handler = PressureHandler::shrink_to(0.5);

        let result = handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert!(result.is_some());
        assert!(result.unwrap() >= 1);
    }

    // Callback reaction fires with the new level on rising pressure.
    #[test]
    fn test_pressure_handler_callback() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let handler = PressureHandler::with_callback(move |level| {
            if level == MemoryPressureLevel::Critical {
                called_clone.store(true, Ordering::SeqCst);
            }
        });

        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert!(called.load(Ordering::SeqCst));
    }

    // Dropping back to Normal after Critical must not trigger a reaction.
    #[test]
    fn test_pressure_handler_only_reacts_to_increase() {
        let handler = PressureHandler::shrink_to(0.5);

        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);

        let result = handler.on_pressure_change(MemoryPressureLevel::Normal, 10);
        assert!(result.is_none());
    }

    // current_level always reflects the most recent report.
    #[test]
    fn test_pressure_handler_level_tracking() {
        let handler = PressureHandler::no_reaction();

        assert_eq!(handler.current_level(), MemoryPressureLevel::Normal);

        handler.on_pressure_change(MemoryPressureLevel::Warning, 10);
        assert_eq!(handler.current_level(), MemoryPressureLevel::Warning);

        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert_eq!(handler.current_level(), MemoryPressureLevel::Critical);
    }

    #[test]
    fn test_pressure_reaction_debug() {
        let none = PressureReaction::None;
        assert!(format!("{:?}", none).contains("None"));

        let shrink = PressureReaction::Shrink {
            target_utilization: 0.5,
        };
        assert!(format!("{:?}", shrink).contains("0.5"));

        let callback = PressureReaction::Callback(Box::new(|_| {}));
        assert!(format!("{:?}", callback).contains("Callback"));
    }

    #[test]
    fn test_pressure_handler_debug() {
        let handler = PressureHandler::shrink_to(0.3);
        let debug_str = format!("{:?}", handler);
        assert!(debug_str.contains("PressureHandler"));
        assert!(debug_str.contains("Shrink"));
    }

    // Severity must increase strictly with level and stay within 0.0..=1.0.
    // (Private helpers are reachable because this is a child module.)
    #[test]
    fn test_pressure_severity_values() {
        let normal = PressureHandler::pressure_severity(MemoryPressureLevel::Normal);
        let elevated = PressureHandler::pressure_severity(MemoryPressureLevel::Elevated);
        let warning = PressureHandler::pressure_severity(MemoryPressureLevel::Warning);
        let critical = PressureHandler::pressure_severity(MemoryPressureLevel::Critical);
        let oom = PressureHandler::pressure_severity(MemoryPressureLevel::OutOfMemory);

        assert!(normal < elevated);
        assert!(elevated < warning);
        assert!(warning < critical);
        assert!(critical < oom);
        assert!(oom <= 1.0);
    }
}