use std::alloc::{alloc, dealloc, Layout};
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use parking_lot::Mutex;

use crate::error::{Result, RingKernelError};

/// A buffer resident in device memory.
pub trait GpuBuffer: Send + Sync {
    /// Size of the buffer in bytes.
    fn size(&self) -> usize;

    /// Raw device address of the buffer.
    fn device_ptr(&self) -> usize;

    /// Copy `data` from host memory into this buffer.
    fn copy_from_host(&self, data: &[u8]) -> Result<()>;

    /// Copy the contents of this buffer into `data` on the host.
    fn copy_to_host(&self, data: &mut [u8]) -> Result<()>;
}

/// A device memory allocator.
pub trait DeviceMemory: Send + Sync {
    /// Allocate a buffer of `size` bytes.
    fn allocate(&self, size: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Allocate a buffer of `size` bytes with the given `alignment`.
    fn allocate_aligned(&self, size: usize, alignment: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Total device memory in bytes.
    fn total_memory(&self) -> usize;

    /// Currently free device memory in bytes.
    fn free_memory(&self) -> usize;
}
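
// Illustrative sketch: a host-backed mock of the two traits above, handy for
// exercising code that is generic over `DeviceMemory` without real hardware.
// `HostBuffer` and `HostDevice` are invented names for this example only; they
// are not part of this crate's API, and the mock panics on oversized copies.
#[cfg(test)]
mod mock_device_example {
    use super::*;

    struct HostBuffer(Mutex<Vec<u8>>);

    impl GpuBuffer for HostBuffer {
        fn size(&self) -> usize {
            self.0.lock().len()
        }

        fn device_ptr(&self) -> usize {
            self.0.lock().as_ptr() as usize
        }

        fn copy_from_host(&self, data: &[u8]) -> Result<()> {
            self.0.lock()[..data.len()].copy_from_slice(data);
            Ok(())
        }

        fn copy_to_host(&self, data: &mut [u8]) -> Result<()> {
            data.copy_from_slice(&self.0.lock()[..data.len()]);
            Ok(())
        }
    }

    struct HostDevice;

    impl DeviceMemory for HostDevice {
        fn allocate(&self, size: usize) -> Result<Box<dyn GpuBuffer>> {
            Ok(Box::new(HostBuffer(Mutex::new(vec![0u8; size]))))
        }

        fn allocate_aligned(&self, size: usize, alignment: usize) -> Result<Box<dyn GpuBuffer>> {
            // A `Vec` does not control its allocation's alignment, so the mock
            // merely rounds the size up; a real backend would align the address.
            self.allocate(align::align_up(size, alignment))
        }

        fn total_memory(&self) -> usize {
            usize::MAX
        }

        fn free_memory(&self) -> usize {
            usize::MAX
        }
    }

    #[test]
    fn round_trip() {
        let device = HostDevice;
        let buffer = device.allocate(16).unwrap();
        buffer.copy_from_host(&[7u8; 16]).unwrap();

        let mut out = [0u8; 16];
        buffer.copy_to_host(&mut out).unwrap();
        assert_eq!(out, [7u8; 16]);
    }
}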

/// Host-side staging buffer of `len` elements of `T`, allocated through the
/// global allocator and freed on drop.
pub struct PinnedMemory<T: Copy> {
    ptr: NonNull<T>,
    len: usize,
    layout: Layout,
    _marker: PhantomData<T>,
}

impl<T: Copy> PinnedMemory<T> {
    /// Allocate an uninitialized buffer for `count` elements.
    pub fn new(count: usize) -> Result<Self> {
        if count == 0 {
            return Err(RingKernelError::InvalidConfig(
                "Cannot allocate zero-sized buffer".to_string(),
            ));
        }

        let layout =
            Layout::array::<T>(count).map_err(|_| RingKernelError::HostAllocationFailed {
                size: count * std::mem::size_of::<T>(),
            })?;

        // Calling `alloc` with a zero-sized layout is undefined behavior, so
        // reject zero-sized element types up front.
        if layout.size() == 0 {
            return Err(RingKernelError::InvalidConfig(
                "Cannot allocate buffer of zero-sized elements".to_string(),
            ));
        }

        // SAFETY: `layout` is non-zero-sized (checked above).
        let ptr = unsafe { alloc(layout) };

        if ptr.is_null() {
            return Err(RingKernelError::HostAllocationFailed {
                size: layout.size(),
            });
        }

        Ok(Self {
            ptr: NonNull::new(ptr as *mut T).unwrap(),
            len: count,
            layout,
            _marker: PhantomData,
        })
    }

    /// Allocate a buffer and initialize it from `data`.
    pub fn from_slice(data: &[T]) -> Result<Self> {
        let mut mem = Self::new(data.len())?;
        mem.as_mut_slice().copy_from_slice(data);
        Ok(mem)
    }

    /// View the buffer as an immutable slice.
    pub fn as_slice(&self) -> &[T] {
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }

    /// View the buffer as a mutable slice.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }

    /// Raw pointer to the first element.
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr()
    }

    /// Raw mutable pointer to the first element.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Number of elements in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Total size of the buffer in bytes.
    pub fn size_bytes(&self) -> usize {
        self.len * std::mem::size_of::<T>()
    }
}

impl<T: Copy> Drop for PinnedMemory<T> {
    fn drop(&mut self) {
        // SAFETY: `ptr` was allocated in `new` with exactly this `layout`.
        unsafe {
            dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
        }
    }
}

// SAFETY: `PinnedMemory` uniquely owns its allocation, so sending or sharing
// it is as safe as sending or sharing `T` itself.
unsafe impl<T: Copy + Send> Send for PinnedMemory<T> {}
unsafe impl<T: Copy + Sync> Sync for PinnedMemory<T> {}
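
// Illustrative sketch: staging data through `PinnedMemory`. This only
// exercises the host-side API defined above; nothing here touches a device.
#[cfg(test)]
mod pinned_memory_example {
    use super::*;

    #[test]
    fn stage_and_read_back() {
        let staged = PinnedMemory::from_slice(&[1u32, 2, 3, 4]).unwrap();
        assert_eq!(staged.len(), 4);
        assert_eq!(staged.size_bytes(), 16);

        // `as_ptr()` is the host address a transfer API would consume.
        assert!(!staged.as_ptr().is_null());
        assert_eq!(staged.as_slice(), &[1, 2, 3, 4]);
    }
}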

/// Fixed-size buffer pool that recycles `Vec<u8>` allocations instead of
/// hitting the allocator on every request.
pub struct MemoryPool {
    name: String,
    /// Size in bytes of every buffer handed out by this pool.
    buffer_size: usize,
    /// Maximum number of buffers retained on the free list.
    max_buffers: usize,
    free_list: Mutex<Vec<Vec<u8>>>,
    total_allocations: AtomicUsize,
    cache_hits: AtomicUsize,
    pool_size: AtomicUsize,
}

impl MemoryPool {
    /// Create a pool whose buffers are `buffer_size` bytes each, retaining at
    /// most `max_buffers` on the free list.
    pub fn new(name: impl Into<String>, buffer_size: usize, max_buffers: usize) -> Self {
        Self {
            name: name.into(),
            buffer_size,
            max_buffers,
            free_list: Mutex::new(Vec::with_capacity(max_buffers)),
            total_allocations: AtomicUsize::new(0),
            cache_hits: AtomicUsize::new(0),
            pool_size: AtomicUsize::new(0),
        }
    }

    /// Take a zeroed buffer, reusing a pooled one when available.
    pub fn allocate(&self) -> PooledBuffer<'_> {
        self.total_allocations.fetch_add(1, Ordering::Relaxed);

        let buffer = {
            let mut free = self.free_list.lock();
            if let Some(buf) = free.pop() {
                self.cache_hits.fetch_add(1, Ordering::Relaxed);
                self.pool_size.fetch_sub(1, Ordering::Relaxed);
                buf
            } else {
                vec![0u8; self.buffer_size]
            }
        };

        PooledBuffer {
            buffer: Some(buffer),
            pool: self,
        }
    }

    /// Return a buffer to the free list, re-zeroing it first; the buffer is
    /// simply dropped if the pool is already at capacity.
    fn return_buffer(&self, mut buffer: Vec<u8>) {
        let mut free = self.free_list.lock();
        if free.len() < self.max_buffers {
            buffer.clear();
            buffer.resize(self.buffer_size, 0);
            free.push(buffer);
            self.pool_size.fetch_add(1, Ordering::Relaxed);
        }
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Number of buffers currently sitting on the free list.
    pub fn current_size(&self) -> usize {
        self.pool_size.load(Ordering::Relaxed)
    }

    /// Fraction of allocations served from the free list; 0.0 before any
    /// allocation has been made.
    pub fn hit_rate(&self) -> f64 {
        let total = self.total_allocations.load(Ordering::Relaxed);
        let hits = self.cache_hits.load(Ordering::Relaxed);
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Fill the free list up to `count` buffers, capped at `max_buffers`.
    pub fn preallocate(&self, count: usize) {
        let count = count.min(self.max_buffers);
        let mut free = self.free_list.lock();
        for _ in free.len()..count {
            free.push(vec![0u8; self.buffer_size]);
            self.pool_size.fetch_add(1, Ordering::Relaxed);
        }
    }
}

/// RAII handle to a pooled buffer; the buffer returns to its pool on drop.
pub struct PooledBuffer<'a> {
    buffer: Option<Vec<u8>>,
    pool: &'a MemoryPool,
}

impl<'a> PooledBuffer<'a> {
    pub fn as_slice(&self) -> &[u8] {
        self.buffer.as_deref().unwrap_or(&[])
    }

    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.buffer.as_deref_mut().unwrap_or(&mut [])
    }

    pub fn len(&self) -> usize {
        self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<'a> Drop for PooledBuffer<'a> {
    fn drop(&mut self) {
        if let Some(buffer) = self.buffer.take() {
            self.pool.return_buffer(buffer);
        }
    }
}

impl<'a> std::ops::Deref for PooledBuffer<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<'a> std::ops::DerefMut for PooledBuffer<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}

/// A `MemoryPool` shared across threads.
pub type SharedMemoryPool = Arc<MemoryPool>;

/// Convenience constructor for a shared memory pool.
pub fn create_pool(
    name: impl Into<String>,
    buffer_size: usize,
    max_buffers: usize,
) -> SharedMemoryPool {
    Arc::new(MemoryPool::new(name, buffer_size, max_buffers))
}
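
// Illustrative sketch: the allocate/drop recycling cycle for `MemoryPool`,
// shared across threads via `create_pool`. The pool name and sizes here are
// arbitrary values chosen for the example.
#[cfg(test)]
mod memory_pool_example {
    use super::*;

    #[test]
    fn recycle_across_threads() {
        let pool = create_pool("staging", 4096, 8);

        let worker = {
            let pool = Arc::clone(&pool);
            std::thread::spawn(move || {
                let mut buf = pool.allocate(); // miss: freshly zeroed
                buf[0] = 0xAB;
                // Dropped here: returned to the free list and re-zeroed.
            })
        };
        worker.join().unwrap();

        let buf = pool.allocate(); // hit: the recycled buffer
        assert_eq!(buf[0], 0);
        assert_eq!(pool.hit_rate(), 0.5); // one miss, one hit
    }
}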

/// Size classes used by `StratifiedMemoryPool`, each four times larger than
/// the previous one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SizeBucket {
    /// 256 B
    Tiny,
    /// 1 KiB
    Small,
    /// 4 KiB
    #[default]
    Medium,
    /// 16 KiB
    Large,
    /// 64 KiB
    Huge,
}

impl SizeBucket {
    /// All buckets in ascending size order.
    pub const ALL: [SizeBucket; 5] = [
        SizeBucket::Tiny,
        SizeBucket::Small,
        SizeBucket::Medium,
        SizeBucket::Large,
        SizeBucket::Huge,
    ];

    /// Buffer capacity of this bucket in bytes.
    pub fn size(&self) -> usize {
        match self {
            Self::Tiny => 256,
            Self::Small => 1024,
            Self::Medium => 4096,
            Self::Large => 16384,
            Self::Huge => 65536,
        }
    }

    /// The smallest bucket whose capacity is at least `requested` bytes;
    /// anything above 16 KiB maps to `Huge`.
    pub fn for_size(requested: usize) -> Self {
        if requested <= 256 {
            Self::Tiny
        } else if requested <= 1024 {
            Self::Small
        } else if requested <= 4096 {
            Self::Medium
        } else if requested <= 16384 {
            Self::Large
        } else {
            Self::Huge
        }
    }

    /// The next larger bucket; `Huge` saturates.
    pub fn upgrade(&self) -> Self {
        match self {
            Self::Tiny => Self::Small,
            Self::Small => Self::Medium,
            Self::Medium => Self::Large,
            Self::Large => Self::Huge,
            Self::Huge => Self::Huge,
        }
    }

    /// The next smaller bucket; `Tiny` saturates.
    pub fn downgrade(&self) -> Self {
        match self {
            Self::Tiny => Self::Tiny,
            Self::Small => Self::Tiny,
            Self::Medium => Self::Small,
            Self::Large => Self::Medium,
            Self::Huge => Self::Large,
        }
    }
}

impl std::fmt::Display for SizeBucket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Tiny => write!(f, "Tiny(256B)"),
            Self::Small => write!(f, "Small(1KB)"),
            Self::Medium => write!(f, "Medium(4KB)"),
            Self::Large => write!(f, "Large(16KB)"),
            Self::Huge => write!(f, "Huge(64KB)"),
        }
    }
}
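
// Illustrative sketch: how a request size maps onto a bucket and how the
// neighboring buckets relate. All values come from the `size()` table above.
#[cfg(test)]
mod size_bucket_example {
    use super::*;

    #[test]
    fn request_mapping() {
        let bucket = SizeBucket::for_size(3000); // 3000 B fits in 4 KiB
        assert_eq!(bucket, SizeBucket::Medium);
        assert!(bucket.size() >= 3000);

        // One step in either direction scales capacity by a factor of four.
        assert_eq!(bucket.downgrade().size() * 4, bucket.size());
        assert_eq!(bucket.size() * 4, bucket.upgrade().size());
        assert_eq!(bucket.to_string(), "Medium(4KB)");
    }
}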

/// Aggregate allocation statistics for a `StratifiedMemoryPool`.
#[derive(Debug, Clone, Default)]
pub struct StratifiedPoolStats {
    pub total_allocations: usize,
    pub total_hits: usize,
    pub allocations_per_bucket: std::collections::HashMap<SizeBucket, usize>,
    pub hits_per_bucket: std::collections::HashMap<SizeBucket, usize>,
}

impl StratifiedPoolStats {
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.total_hits as f64 / self.total_allocations as f64
        }
    }

    pub fn bucket_hit_rate(&self, bucket: SizeBucket) -> f64 {
        let allocs = self
            .allocations_per_bucket
            .get(&bucket)
            .copied()
            .unwrap_or(0);
        let hits = self.hits_per_bucket.get(&bucket).copied().unwrap_or(0);
        if allocs == 0 {
            0.0
        } else {
            hits as f64 / allocs as f64
        }
    }
}

/// Size-stratified buffer pool: one `MemoryPool` per `SizeBucket`, so
/// requests of different sizes do not compete for the same free list.
pub struct StratifiedMemoryPool {
    name: String,
    buckets: std::collections::HashMap<SizeBucket, MemoryPool>,
    max_buffers_per_bucket: usize,
    stats: Mutex<StratifiedPoolStats>,
}

impl StratifiedMemoryPool {
    /// Create a pool with the default capacity of 16 buffers per bucket.
    pub fn new(name: impl Into<String>) -> Self {
        Self::with_capacity(name, 16)
    }

    /// Create a pool retaining at most `max_buffers_per_bucket` buffers on
    /// each bucket's free list.
    pub fn with_capacity(name: impl Into<String>, max_buffers_per_bucket: usize) -> Self {
        let name = name.into();
        let mut buckets = std::collections::HashMap::new();

        for bucket in SizeBucket::ALL {
            let pool_name = format!("{}_{}", name, bucket);
            buckets.insert(
                bucket,
                MemoryPool::new(pool_name, bucket.size(), max_buffers_per_bucket),
            );
        }

        Self {
            name,
            buckets,
            max_buffers_per_bucket,
            stats: Mutex::new(StratifiedPoolStats::default()),
        }
    }

    /// Allocate a buffer large enough for `size` bytes from the matching bucket.
    pub fn allocate(&self, size: usize) -> StratifiedBuffer<'_> {
        let bucket = SizeBucket::for_size(size);
        self.allocate_bucket(bucket)
    }

    /// Allocate directly from a specific bucket.
    pub fn allocate_bucket(&self, bucket: SizeBucket) -> StratifiedBuffer<'_> {
        let pool = self.buckets.get(&bucket).expect("bucket pool exists");

        // Checked before allocating so the hit can be attributed to this bucket.
        let was_cached = pool.current_size() > 0;
        let buffer = pool.allocate();

        {
            let mut stats = self.stats.lock();
            stats.total_allocations += 1;
            *stats.allocations_per_bucket.entry(bucket).or_insert(0) += 1;
            if was_cached {
                stats.total_hits += 1;
                *stats.hits_per_bucket.entry(bucket).or_insert(0) += 1;
            }
        }

        StratifiedBuffer {
            inner: buffer,
            bucket,
            pool: self,
        }
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn max_buffers_per_bucket(&self) -> usize {
        self.max_buffers_per_bucket
    }

    /// Number of pooled buffers currently free in `bucket`.
    pub fn bucket_size(&self, bucket: SizeBucket) -> usize {
        self.buckets
            .get(&bucket)
            .map(|p| p.current_size())
            .unwrap_or(0)
    }

    /// Total pooled buffers across all buckets.
    pub fn total_pooled(&self) -> usize {
        self.buckets.values().map(|p| p.current_size()).sum()
    }

    /// Snapshot of the allocation statistics.
    pub fn stats(&self) -> StratifiedPoolStats {
        self.stats.lock().clone()
    }

    /// Pre-fill one bucket's free list with `count` buffers.
    pub fn preallocate(&self, bucket: SizeBucket, count: usize) {
        if let Some(pool) = self.buckets.get(&bucket) {
            pool.preallocate(count);
        }
    }

    /// Pre-fill every bucket's free list with `count_per_bucket` buffers.
    pub fn preallocate_all(&self, count_per_bucket: usize) {
        for bucket in SizeBucket::ALL {
            self.preallocate(bucket, count_per_bucket);
        }
    }

    /// Drop pooled buffers until no bucket holds more than `target_per_bucket`,
    /// e.g. in response to memory pressure.
    pub fn shrink_to(&self, target_per_bucket: usize) {
        for pool in self.buckets.values() {
            let mut free_list = pool.free_list.lock();
            while free_list.len() > target_per_bucket {
                free_list.pop();
                pool.pool_size.fetch_sub(1, Ordering::Relaxed);
            }
        }
    }
}

/// RAII handle to a bucketed buffer; it returns to its bucket's pool on drop.
pub struct StratifiedBuffer<'a> {
    inner: PooledBuffer<'a>,
    bucket: SizeBucket,
    #[allow(dead_code)]
    pool: &'a StratifiedMemoryPool,
}

impl<'a> StratifiedBuffer<'a> {
    /// The bucket this buffer was drawn from.
    pub fn bucket(&self) -> SizeBucket {
        self.bucket
    }

    /// Capacity in bytes (the bucket size, not the requested size).
    pub fn capacity(&self) -> usize {
        self.bucket.size()
    }

    pub fn as_slice(&self) -> &[u8] {
        self.inner.as_slice()
    }

    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.inner.as_mut_slice()
    }

    pub fn len(&self) -> usize {
        self.inner.len()
    }

    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}

impl<'a> std::ops::Deref for StratifiedBuffer<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<'a> std::ops::DerefMut for StratifiedBuffer<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}

/// A `StratifiedMemoryPool` shared across threads.
pub type SharedStratifiedPool = Arc<StratifiedMemoryPool>;

/// Convenience constructor using the default per-bucket capacity.
pub fn create_stratified_pool(name: impl Into<String>) -> SharedStratifiedPool {
    Arc::new(StratifiedMemoryPool::new(name))
}

/// Convenience constructor with an explicit per-bucket capacity.
pub fn create_stratified_pool_with_capacity(
    name: impl Into<String>,
    max_buffers_per_bucket: usize,
) -> SharedStratifiedPool {
    Arc::new(StratifiedMemoryPool::with_capacity(
        name,
        max_buffers_per_bucket,
    ))
}
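
// Illustrative sketch: sizing mixed requests against one stratified pool and
// reading back per-bucket occupancy. The request sizes are arbitrary.
#[cfg(test)]
mod stratified_pool_example {
    use super::*;

    #[test]
    fn mixed_sizes() {
        let pool = create_stratified_pool("frames");

        {
            let header = pool.allocate(64); // -> Tiny (256 B capacity)
            let payload = pool.allocate(3000); // -> Medium (4 KiB capacity)
            assert_eq!(header.capacity(), 256);
            assert_eq!(payload.capacity(), 4096);
        } // both buffers return to their buckets here

        assert_eq!(pool.bucket_size(SizeBucket::Tiny), 1);
        assert_eq!(pool.bucket_size(SizeBucket::Medium), 1);
        assert_eq!(pool.stats().total_allocations, 2);
    }
}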

use crate::observability::MemoryPressureLevel;

/// Callback invoked when the observed memory pressure level rises.
pub type PressureCallback = Box<dyn Fn(MemoryPressureLevel) + Send + Sync>;

/// Policy for responding to rising memory pressure.
pub enum PressureReaction {
    /// Track the level but take no action.
    None,
    /// Shrink pool free lists toward a fraction of their capacity.
    Shrink {
        /// Baseline fraction of capacity to keep, in `[0.0, 1.0]`.
        target_utilization: f64,
    },
    /// Delegate the response to a user-supplied callback.
    Callback(PressureCallback),
}

impl std::fmt::Debug for PressureReaction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => write!(f, "PressureReaction::None"),
            Self::Shrink { target_utilization } => {
                write!(
                    f,
                    "PressureReaction::Shrink {{ target_utilization: {} }}",
                    target_utilization
                )
            }
            Self::Callback(_) => write!(f, "PressureReaction::Callback(<fn>)"),
        }
    }
}

/// Tracks the current memory pressure level and converts level increases
/// into a configured `PressureReaction`.
pub struct PressureHandler {
    reaction: PressureReaction,
    current_level: Mutex<MemoryPressureLevel>,
}

impl PressureHandler {
    pub fn new(reaction: PressureReaction) -> Self {
        Self {
            reaction,
            current_level: Mutex::new(MemoryPressureLevel::Normal),
        }
    }

    /// Handler that tracks levels but never reacts.
    pub fn no_reaction() -> Self {
        Self::new(PressureReaction::None)
    }

    /// Handler that shrinks pools toward `target_utilization`, clamped to
    /// `[0.0, 1.0]`.
    pub fn shrink_to(target_utilization: f64) -> Self {
        Self::new(PressureReaction::Shrink {
            target_utilization: target_utilization.clamp(0.0, 1.0),
        })
    }

    /// Handler that invokes `callback` on every pressure increase.
    pub fn with_callback<F>(callback: F) -> Self
    where
        F: Fn(MemoryPressureLevel) + Send + Sync + 'static,
    {
        Self::new(PressureReaction::Callback(Box::new(callback)))
    }

    /// The most recently observed pressure level.
    pub fn current_level(&self) -> MemoryPressureLevel {
        *self.current_level.lock()
    }

    /// Record a new pressure level. On an increase with a `Shrink` reaction,
    /// returns the suggested per-bucket buffer count (at least 1); in every
    /// other case, including decreases, returns `None`.
    pub fn on_pressure_change(
        &self,
        new_level: MemoryPressureLevel,
        max_per_bucket: usize,
    ) -> Option<usize> {
        let old_level = {
            let mut current = self.current_level.lock();
            let old = *current;
            *current = new_level;
            old
        };

        if !Self::is_higher_pressure(new_level, old_level) {
            return None;
        }

        match &self.reaction {
            PressureReaction::None => None,
            PressureReaction::Shrink { target_utilization } => {
                // Scale the target down further the more severe the pressure.
                let pressure_factor = Self::pressure_severity(new_level);
                let adjusted_target = target_utilization * (1.0 - pressure_factor);
                let target_count = ((max_per_bucket as f64) * adjusted_target) as usize;
                Some(target_count.max(1))
            }
            PressureReaction::Callback(callback) => {
                callback(new_level);
                None
            }
        }
    }

    fn is_higher_pressure(new: MemoryPressureLevel, old: MemoryPressureLevel) -> bool {
        Self::pressure_ordinal(new) > Self::pressure_ordinal(old)
    }

    fn pressure_ordinal(level: MemoryPressureLevel) -> u8 {
        match level {
            MemoryPressureLevel::Normal => 0,
            MemoryPressureLevel::Elevated => 1,
            MemoryPressureLevel::Warning => 2,
            MemoryPressureLevel::Critical => 3,
            MemoryPressureLevel::OutOfMemory => 4,
        }
    }

    /// Severity in `[0.0, 1.0]` used to scale shrink targets.
    fn pressure_severity(level: MemoryPressureLevel) -> f64 {
        match level {
            MemoryPressureLevel::Normal => 0.0,
            MemoryPressureLevel::Elevated => 0.2,
            MemoryPressureLevel::Warning => 0.5,
            MemoryPressureLevel::Critical => 0.8,
            MemoryPressureLevel::OutOfMemory => 1.0,
        }
    }
}

impl std::fmt::Debug for PressureHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PressureHandler")
            .field("reaction", &self.reaction)
            .field("current_level", &self.current_level())
            .finish()
    }
}

/// Implemented by pools that can respond to memory pressure notifications.
pub trait PressureAwarePool {
    /// Handle a pressure level change, returning `true` if any action was taken.
    fn handle_pressure(&self, level: MemoryPressureLevel) -> bool;

    /// The pool's current pressure level.
    fn pressure_level(&self) -> MemoryPressureLevel;
}
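
// Illustrative sketch: wiring a `PressureHandler` to a stratified pool so a
// rising pressure level shrinks the per-bucket free lists. The glue code is
// an assumption for the example; this module does not prescribe how the two
// are combined.
#[cfg(test)]
mod pressure_example {
    use super::*;

    #[test]
    fn shrink_on_rising_pressure() {
        let pool = StratifiedMemoryPool::with_capacity("frames", 16);
        pool.preallocate_all(16);
        let handler = PressureHandler::shrink_to(0.5);

        // Pressure rises: the handler proposes a smaller per-bucket target
        // (16 * 0.5 * (1.0 - 0.5) = 4 at the Warning severity of 0.5).
        if let Some(target) =
            handler.on_pressure_change(MemoryPressureLevel::Warning, pool.max_buffers_per_bucket())
        {
            pool.shrink_to(target);
            assert!(pool.bucket_size(SizeBucket::Tiny) <= target);
        } else {
            panic!("expected a shrink target on rising pressure");
        }

        // Falling back to Normal is not an increase, so nothing is proposed.
        assert!(handler
            .on_pressure_change(MemoryPressureLevel::Normal, 16)
            .is_none());
    }
}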

/// Alignment helpers. All functions assume `alignment` is a power of two.
pub mod align {
    /// Typical CPU cache line size in bytes.
    pub const CACHE_LINE_SIZE: usize = 64;

    /// Typical GPU cache line size in bytes.
    pub const GPU_CACHE_LINE_SIZE: usize = 128;

    /// Round `value` up to the next multiple of `alignment`.
    #[inline]
    pub const fn align_up(value: usize, alignment: usize) -> usize {
        let mask = alignment - 1;
        (value + mask) & !mask
    }

    /// Round `value` down to the previous multiple of `alignment`.
    #[inline]
    pub const fn align_down(value: usize, alignment: usize) -> usize {
        let mask = alignment - 1;
        value & !mask
    }

    /// Whether `value` is a multiple of `alignment`.
    #[inline]
    pub const fn is_aligned(value: usize, alignment: usize) -> bool {
        value & (alignment - 1) == 0
    }

    /// Bytes of padding needed to advance `offset` to the next aligned offset.
    #[inline]
    pub const fn padding_for(offset: usize, alignment: usize) -> usize {
        let misalignment = offset & (alignment - 1);
        if misalignment == 0 {
            0
        } else {
            alignment - misalignment
        }
    }
}
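
// Illustrative sketch: using the helpers above to lay out two regions in one
// backing allocation so the second starts on a GPU cache line boundary. The
// 40-byte header size is an arbitrary value for the example.
#[cfg(test)]
mod align_example {
    use super::align::*;

    #[test]
    fn packed_layout() {
        let header_bytes = 40;

        // Header occupies [0, 40); pad so the payload starts at offset 128.
        let payload_offset = header_bytes + padding_for(header_bytes, GPU_CACHE_LINE_SIZE);
        assert_eq!(payload_offset, 128);
        assert!(is_aligned(payload_offset, GPU_CACHE_LINE_SIZE));

        // Equivalent formulation in a single call.
        assert_eq!(payload_offset, align_up(header_bytes, GPU_CACHE_LINE_SIZE));
    }
}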

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pinned_memory() {
        let mut mem = PinnedMemory::<f32>::new(1024).unwrap();
        assert_eq!(mem.len(), 1024);
        assert_eq!(mem.size_bytes(), 1024 * 4);

        let slice = mem.as_mut_slice();
        for (i, v) in slice.iter_mut().enumerate() {
            *v = i as f32;
        }

        assert_eq!(mem.as_slice()[42], 42.0);
    }

    #[test]
    fn test_pinned_memory_from_slice() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mem = PinnedMemory::from_slice(&data).unwrap();
        assert_eq!(mem.as_slice(), &data[..]);
    }

    #[test]
    fn test_memory_pool() {
        let pool = MemoryPool::new("test", 1024, 10);

        let buf1 = pool.allocate();
        assert_eq!(buf1.len(), 1024);
        drop(buf1);

        let _buf2 = pool.allocate();
        // One miss followed by one hit: 1 / 2.
        assert_eq!(pool.hit_rate(), 0.5);
    }

    #[test]
    fn test_pool_preallocate() {
        let pool = MemoryPool::new("test", 1024, 10);
        pool.preallocate(5);
        assert_eq!(pool.current_size(), 5);

        // Each buffer drops back into the pool immediately, so every
        // allocation is served from the free list.
        for _ in 0..5 {
            let _ = pool.allocate();
        }
        assert_eq!(pool.hit_rate(), 1.0);
    }

    #[test]
    fn test_align_up() {
        use align::*;

        assert_eq!(align_up(0, 64), 0);
        assert_eq!(align_up(1, 64), 64);
        assert_eq!(align_up(64, 64), 64);
        assert_eq!(align_up(65, 64), 128);
    }

    #[test]
    fn test_is_aligned() {
        use align::*;

        assert!(is_aligned(0, 64));
        assert!(is_aligned(64, 64));
        assert!(is_aligned(128, 64));
        assert!(!is_aligned(1, 64));
        assert!(!is_aligned(63, 64));
    }

    #[test]
    fn test_padding_for() {
        use align::*;

        assert_eq!(padding_for(0, 64), 0);
        assert_eq!(padding_for(1, 64), 63);
        assert_eq!(padding_for(63, 64), 1);
        assert_eq!(padding_for(64, 64), 0);
    }

    #[test]
    fn test_size_bucket_sizes() {
        assert_eq!(SizeBucket::Tiny.size(), 256);
        assert_eq!(SizeBucket::Small.size(), 1024);
        assert_eq!(SizeBucket::Medium.size(), 4096);
        assert_eq!(SizeBucket::Large.size(), 16384);
        assert_eq!(SizeBucket::Huge.size(), 65536);
    }

    #[test]
    fn test_size_bucket_selection() {
        assert_eq!(SizeBucket::for_size(0), SizeBucket::Tiny);
        assert_eq!(SizeBucket::for_size(256), SizeBucket::Tiny);
        assert_eq!(SizeBucket::for_size(257), SizeBucket::Small);
        assert_eq!(SizeBucket::for_size(1024), SizeBucket::Small);
        assert_eq!(SizeBucket::for_size(1025), SizeBucket::Medium);
        assert_eq!(SizeBucket::for_size(4096), SizeBucket::Medium);
        assert_eq!(SizeBucket::for_size(4097), SizeBucket::Large);
        assert_eq!(SizeBucket::for_size(16384), SizeBucket::Large);
        assert_eq!(SizeBucket::for_size(16385), SizeBucket::Huge);
        assert_eq!(SizeBucket::for_size(100000), SizeBucket::Huge);
    }

    #[test]
    fn test_size_bucket_upgrade_downgrade() {
        assert_eq!(SizeBucket::Tiny.upgrade(), SizeBucket::Small);
        assert_eq!(SizeBucket::Small.upgrade(), SizeBucket::Medium);
        assert_eq!(SizeBucket::Medium.upgrade(), SizeBucket::Large);
        assert_eq!(SizeBucket::Large.upgrade(), SizeBucket::Huge);
        assert_eq!(SizeBucket::Huge.upgrade(), SizeBucket::Huge);

        assert_eq!(SizeBucket::Tiny.downgrade(), SizeBucket::Tiny);
        assert_eq!(SizeBucket::Small.downgrade(), SizeBucket::Tiny);
        assert_eq!(SizeBucket::Medium.downgrade(), SizeBucket::Small);
        assert_eq!(SizeBucket::Large.downgrade(), SizeBucket::Medium);
        assert_eq!(SizeBucket::Huge.downgrade(), SizeBucket::Large);
    }

    #[test]
    fn test_stratified_pool_allocation() {
        let pool = StratifiedMemoryPool::new("test");

        let buf1 = pool.allocate(100);
        let buf2 = pool.allocate(500);
        let buf3 = pool.allocate(2000);

        assert_eq!(buf1.bucket(), SizeBucket::Tiny);
        assert_eq!(buf2.bucket(), SizeBucket::Small);
        assert_eq!(buf3.bucket(), SizeBucket::Medium);

        assert_eq!(buf1.capacity(), 256);
        assert_eq!(buf2.capacity(), 1024);
        assert_eq!(buf3.capacity(), 4096);
    }

    #[test]
    fn test_stratified_pool_reuse() {
        let pool = StratifiedMemoryPool::new("test");

        // First allocation misses; the buffer returns at the end of the block.
        {
            let _buf = pool.allocate(100);
        }
        // Second allocation of the same size hits the pooled buffer.
        {
            let _buf = pool.allocate(100);
        }

        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.total_hits, 1);
        assert!((stats.hit_rate() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_stratified_pool_stats_per_bucket() {
        let pool = StratifiedMemoryPool::new("test");

        let _buf1 = pool.allocate(100); // Tiny
        let _buf2 = pool.allocate(500); // Small
        let _buf3 = pool.allocate(100); // Tiny

        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 3);
        assert_eq!(
            stats.allocations_per_bucket.get(&SizeBucket::Tiny),
            Some(&2)
        );
        assert_eq!(
            stats.allocations_per_bucket.get(&SizeBucket::Small),
            Some(&1)
        );
    }

    #[test]
    fn test_stratified_pool_preallocate() {
        let pool = StratifiedMemoryPool::new("test");

        pool.preallocate(SizeBucket::Medium, 5);
        assert_eq!(pool.bucket_size(SizeBucket::Medium), 5);
        assert_eq!(pool.bucket_size(SizeBucket::Tiny), 0);

        for _ in 0..5 {
            let _buf = pool.allocate(2000);
        }

        let stats = pool.stats();
        assert_eq!(stats.hits_per_bucket.get(&SizeBucket::Medium), Some(&5));
    }

    #[test]
    fn test_stratified_pool_shrink() {
        let pool = StratifiedMemoryPool::new("test");

        pool.preallocate_all(10);
        assert_eq!(pool.total_pooled(), 50); // 10 buffers x 5 buckets

        pool.shrink_to(2);
        assert_eq!(pool.total_pooled(), 10); // 2 buffers x 5 buckets
    }

    #[test]
    fn test_stratified_buffer_deref() {
        let pool = StratifiedMemoryPool::new("test");

        let mut buf = pool.allocate(100);

        buf[0] = 42;
        buf[1] = 43;

        assert_eq!(buf[0], 42);
        assert_eq!(buf[1], 43);
    }

    #[test]
    fn test_pressure_handler_no_reaction() {
        let handler = PressureHandler::no_reaction();
        assert_eq!(handler.current_level(), MemoryPressureLevel::Normal);

        let result = handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert!(result.is_none());
    }

    #[test]
    fn test_pressure_handler_shrink() {
        let handler = PressureHandler::shrink_to(0.5);

        let result = handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert!(result.is_some());
        // 10 * 0.5 * (1.0 - 0.8) = 1, and the target is floored at 1.
        assert!(result.unwrap() >= 1);
    }

    #[test]
    fn test_pressure_handler_callback() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let handler = PressureHandler::with_callback(move |level| {
            if level == MemoryPressureLevel::Critical {
                called_clone.store(true, Ordering::SeqCst);
            }
        });

        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn test_pressure_handler_only_reacts_to_increase() {
        let handler = PressureHandler::shrink_to(0.5);

        // Raise the level first...
        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);

        // ...then confirm that dropping back down triggers no reaction.
        let result = handler.on_pressure_change(MemoryPressureLevel::Normal, 10);
        assert!(result.is_none());
    }

    #[test]
    fn test_pressure_handler_level_tracking() {
        let handler = PressureHandler::no_reaction();

        assert_eq!(handler.current_level(), MemoryPressureLevel::Normal);

        handler.on_pressure_change(MemoryPressureLevel::Warning, 10);
        assert_eq!(handler.current_level(), MemoryPressureLevel::Warning);

        handler.on_pressure_change(MemoryPressureLevel::Critical, 10);
        assert_eq!(handler.current_level(), MemoryPressureLevel::Critical);
    }

    #[test]
    fn test_pressure_reaction_debug() {
        let none = PressureReaction::None;
        assert!(format!("{:?}", none).contains("None"));

        let shrink = PressureReaction::Shrink {
            target_utilization: 0.5,
        };
        assert!(format!("{:?}", shrink).contains("0.5"));

        let callback = PressureReaction::Callback(Box::new(|_| {}));
        assert!(format!("{:?}", callback).contains("Callback"));
    }

    #[test]
    fn test_pressure_handler_debug() {
        let handler = PressureHandler::shrink_to(0.3);
        let debug_str = format!("{:?}", handler);
        assert!(debug_str.contains("PressureHandler"));
        assert!(debug_str.contains("Shrink"));
    }

    #[test]
    fn test_pressure_severity_values() {
        let normal = PressureHandler::pressure_severity(MemoryPressureLevel::Normal);
        let elevated = PressureHandler::pressure_severity(MemoryPressureLevel::Elevated);
        let warning = PressureHandler::pressure_severity(MemoryPressureLevel::Warning);
        let critical = PressureHandler::pressure_severity(MemoryPressureLevel::Critical);
        let oom = PressureHandler::pressure_severity(MemoryPressureLevel::OutOfMemory);

        assert!(normal < elevated);
        assert!(elevated < warning);
        assert!(warning < critical);
        assert!(critical < oom);
        assert!(oom <= 1.0);
    }
}
1226}