1#![allow(dead_code)]
32
/// Coarse classification of how much of a buffer pool's capacity is in use.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// ranks `Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MemoryPressureLevel {
    /// Usage is below the medium watermark. This is the default level.
    #[default]
    Low,
    /// Usage has reached the medium watermark.
    Medium,
    /// Usage has reached the high watermark.
    High,
    /// Usage has reached the critical watermark.
    Critical,
}
64
/// Fractions of a pool's buffers in use at which each pressure level begins.
///
/// Each watermark is an inclusive lower bound: an in-use fraction exactly
/// equal to `high_watermark` already counts as high pressure.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PressureThresholds {
    /// In-use fraction at or above which pressure is at least `Medium`.
    pub medium_watermark: f64,
    /// In-use fraction at or above which pressure is at least `High`.
    pub high_watermark: f64,
    /// In-use fraction at or above which pressure is `Critical`.
    pub critical_watermark: f64,
}

impl Default for PressureThresholds {
    /// Medium at 50%, High at 75%, Critical at 90% of buffers in use.
    fn default() -> Self {
        let (medium_watermark, high_watermark, critical_watermark) = (0.5, 0.75, 0.9);
        Self {
            medium_watermark,
            high_watermark,
            critical_watermark,
        }
    }
}
92
93impl PressureThresholds {
94 #[must_use]
101 pub fn new(medium: f64, high: f64, critical: f64) -> Self {
102 assert!(
103 (0.0..=1.0).contains(&medium)
104 && (0.0..=1.0).contains(&high)
105 && (0.0..=1.0).contains(&critical),
106 "Watermarks must be in [0.0, 1.0]"
107 );
108 assert!(
109 medium <= high && high <= critical,
110 "Watermarks must be in ascending order"
111 );
112 Self {
113 medium_watermark: medium,
114 high_watermark: high,
115 critical_watermark: critical,
116 }
117 }
118
119 #[must_use]
121 pub fn level_for_fraction(&self, in_use_fraction: f64) -> MemoryPressureLevel {
122 if in_use_fraction >= self.critical_watermark {
123 MemoryPressureLevel::Critical
124 } else if in_use_fraction >= self.high_watermark {
125 MemoryPressureLevel::High
126 } else if in_use_fraction >= self.medium_watermark {
127 MemoryPressureLevel::Medium
128 } else {
129 MemoryPressureLevel::Low
130 }
131 }
132}
133
/// Boxed callback that receives the new level whenever a pool's pressure
/// level changes.
///
/// Boxing lets closures of different types live in one `Vec`. The
/// `Send + Sync` bounds keep a pool that stores these callbacks `Send + Sync`.
pub type MemoryPressureCallback = Box<dyn Fn(MemoryPressureLevel) + Send + Sync>;
142
/// Describes the size, alignment, and owning pool of a buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferDesc {
    /// Requested buffer length in bytes.
    pub size_bytes: usize,
    /// Requested alignment in bytes. Stored as metadata only: `PooledBuffer::new`
    /// allocates a plain `Vec<u8>` and does not honour it.
    pub alignment: usize,
    /// Identifier of the owning pool (`BufferPool::new` always uses 0).
    pub pool_id: u32,
}

impl BufferDesc {
    /// Creates a descriptor from its raw parts.
    #[must_use]
    pub fn new(size_bytes: usize, alignment: usize, pool_id: u32) -> Self {
        Self {
            size_bytes,
            alignment,
            pool_id,
        }
    }

    /// Returns `true` if `alignment` guarantees 4 KiB page alignment.
    ///
    /// Any non-zero multiple of 4096 also lands on a page boundary, so this
    /// includes stricter alignments such as 8 KiB or 2 MiB huge pages, not
    /// only 4096 itself. An alignment of 0 is never page-aligned.
    #[must_use]
    pub fn is_page_aligned(&self) -> bool {
        const PAGE_SIZE: usize = 4096;
        self.alignment != 0 && self.alignment % PAGE_SIZE == 0
    }

    /// Number of `slot_size`-byte slots needed to hold `size_bytes`, rounded up.
    ///
    /// # Panics
    ///
    /// Panics if `slot_size` is zero.
    #[must_use]
    pub fn slots_needed(&self, slot_size: usize) -> usize {
        assert!(slot_size > 0, "slot_size must be non-zero");
        self.size_bytes.div_ceil(slot_size)
    }
}
186
187#[derive(Debug)]
189pub struct PooledBuffer {
190 pub id: u64,
192 pub data: Vec<u8>,
194 pub desc: BufferDesc,
196 pub in_use: bool,
198}
199
200impl PooledBuffer {
201 #[must_use]
203 pub fn new(id: u64, desc: BufferDesc) -> Self {
204 let data = vec![0u8; desc.size_bytes];
205 Self {
206 id,
207 data,
208 desc,
209 in_use: false,
210 }
211 }
212
213 pub fn reset(&mut self) {
215 self.data.fill(0);
216 self.in_use = false;
217 }
218
219 #[must_use]
221 pub fn available_size(&self) -> usize {
222 self.data.len()
223 }
224}
225
226pub struct BufferPool {
236 pub buffers: Vec<PooledBuffer>,
238 pub next_id: u64,
240 thresholds: Option<PressureThresholds>,
242 last_pressure: MemoryPressureLevel,
244 pressure_callbacks: Vec<MemoryPressureCallback>,
246}
247
248impl std::fmt::Debug for BufferPool {
249 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250 f.debug_struct("BufferPool")
251 .field("total", &self.buffers.len())
252 .field("available", &self.available_count())
253 .field("last_pressure", &self.last_pressure)
254 .finish()
255 }
256}
257
258impl BufferPool {
259 #[must_use]
264 pub fn new(count: usize, buf_size: usize) -> Self {
265 let mut buffers = Vec::with_capacity(count);
266 for id in 0..count as u64 {
267 let desc = BufferDesc::new(buf_size, 64, 0);
268 buffers.push(PooledBuffer::new(id, desc));
269 }
270 Self {
271 buffers,
272 next_id: count as u64,
273 thresholds: None,
274 last_pressure: MemoryPressureLevel::Low,
275 pressure_callbacks: Vec::new(),
276 }
277 }
278
279 #[must_use]
283 pub fn with_pressure(count: usize, buf_size: usize, thresholds: PressureThresholds) -> Self {
284 let mut pool = Self::new(count, buf_size);
285 pool.thresholds = Some(thresholds);
286 pool
287 }
288
289 pub fn add_pressure_callback(&mut self, cb: MemoryPressureCallback) {
294 self.pressure_callbacks.push(cb);
295 }
296
297 #[must_use]
301 fn in_use_fraction(&self) -> f64 {
302 let total = self.buffers.len();
303 if total == 0 {
304 return 0.0;
305 }
306 let in_use = self.buffers.iter().filter(|b| b.in_use).count();
307 in_use as f64 / total as f64
308 }
309
310 #[must_use]
314 pub fn current_pressure_level(&self) -> MemoryPressureLevel {
315 match &self.thresholds {
316 None => MemoryPressureLevel::Low,
317 Some(t) => t.level_for_fraction(self.in_use_fraction()),
318 }
319 }
320
321 fn notify_pressure(&mut self) {
324 let current = self.current_pressure_level();
325 if current != self.last_pressure {
326 self.last_pressure = current;
327 for cb in &self.pressure_callbacks {
328 cb(current);
329 }
330 }
331 }
332
333 #[must_use]
340 pub fn acquire(&mut self) -> Option<u64> {
341 let acquired = self.buffers.iter_mut().find(|b| !b.in_use).map(|buf| {
342 buf.in_use = true;
343 buf.id
344 });
345 if acquired.is_some() {
346 self.notify_pressure();
347 }
348 acquired
349 }
350
351 pub fn release(&mut self, id: u64) {
356 if let Some(buf) = self.buffers.iter_mut().find(|b| b.id == id) {
357 buf.reset();
358 }
359 self.notify_pressure();
360 }
361
362 pub fn shrink_to(&mut self, target_count: usize) -> usize {
381 let mut removed = 0usize;
382 let mut i = self.buffers.len();
384 while i > 0 && self.buffers.len() > target_count {
385 i -= 1;
386 if !self.buffers[i].in_use {
387 self.buffers.remove(i);
388 removed += 1;
389 }
390 }
391 if removed > 0 {
392 self.notify_pressure();
393 }
394 removed
395 }
396
397 pub fn auto_shrink(&mut self) -> usize {
403 let current_level = self.current_pressure_level();
404 if current_level != MemoryPressureLevel::Low {
405 return 0;
406 }
407 let total = self.buffers.len();
408 let available = self.available_count();
409 if total == 0 || available <= total / 2 {
410 return 0;
411 }
412 let target = (total / 2).max(1);
413 self.shrink_to(target)
414 }
415
416 #[must_use]
420 pub fn available_count(&self) -> usize {
421 self.buffers.iter().filter(|b| !b.in_use).count()
422 }
423
424 #[must_use]
426 pub fn total_count(&self) -> usize {
427 self.buffers.len()
428 }
429
430 #[must_use]
432 pub fn in_use_count(&self) -> usize {
433 self.buffers.iter().filter(|b| b.in_use).count()
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    // ---- BufferDesc ----

    #[test]
    fn test_buffer_desc_new() {
        let desc = BufferDesc::new(1024, 64, 1);
        assert_eq!(desc.size_bytes, 1024);
        assert_eq!(desc.alignment, 64);
        assert_eq!(desc.pool_id, 1);
    }

    #[test]
    fn test_buffer_desc_is_page_aligned_true() {
        let desc = BufferDesc::new(8192, 4096, 0);
        assert!(desc.is_page_aligned());
    }

    #[test]
    fn test_buffer_desc_is_page_aligned_false() {
        let desc = BufferDesc::new(8192, 64, 0);
        assert!(!desc.is_page_aligned());
    }

    #[test]
    fn test_buffer_desc_slots_needed_exact() {
        let desc = BufferDesc::new(1024, 64, 0);
        assert_eq!(desc.slots_needed(512), 2);
    }

    #[test]
    fn test_buffer_desc_slots_needed_round_up() {
        let desc = BufferDesc::new(1025, 64, 0);
        assert_eq!(desc.slots_needed(512), 3);
    }

    #[test]
    fn test_buffer_desc_slots_needed_single_slot() {
        let desc = BufferDesc::new(100, 64, 0);
        assert_eq!(desc.slots_needed(200), 1);
    }

    // ---- PooledBuffer ----

    #[test]
    fn test_pooled_buffer_initial_state() {
        let desc = BufferDesc::new(256, 64, 0);
        let buf = PooledBuffer::new(42, desc);
        assert_eq!(buf.id, 42);
        assert!(!buf.in_use);
        assert_eq!(buf.available_size(), 256);
        assert!(buf.data.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_pooled_buffer_reset() {
        let desc = BufferDesc::new(4, 64, 0);
        let mut buf = PooledBuffer::new(1, desc);
        buf.in_use = true;
        buf.data[0] = 0xFF;
        buf.reset();
        assert!(!buf.in_use);
        assert!(buf.data.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_pooled_buffer_available_size() {
        let desc = BufferDesc::new(512, 64, 0);
        let buf = PooledBuffer::new(0, desc);
        assert_eq!(buf.available_size(), 512);
    }

    // ---- BufferPool basics ----

    #[test]
    fn test_pool_new() {
        let pool = BufferPool::new(4, 1024);
        assert_eq!(pool.total_count(), 4);
        assert_eq!(pool.available_count(), 4);
    }

    #[test]
    fn test_pool_acquire_returns_id() {
        let mut pool = BufferPool::new(2, 256);
        let id = pool.acquire();
        assert!(id.is_some());
    }

    #[test]
    fn test_pool_acquire_exhausts_buffers() {
        let mut pool = BufferPool::new(2, 256);
        let _id1 = pool.acquire().expect("acquire should succeed");
        let _id2 = pool.acquire().expect("acquire should succeed");
        assert!(pool.acquire().is_none());
    }

    #[test]
    fn test_pool_available_count_decrements_on_acquire() {
        let mut pool = BufferPool::new(3, 64);
        assert_eq!(pool.available_count(), 3);
        let _ = pool.acquire();
        assert_eq!(pool.available_count(), 2);
        let _ = pool.acquire();
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn test_pool_release_makes_buffer_available() {
        let mut pool = BufferPool::new(1, 64);
        let id = pool.acquire().expect("acquire should succeed");
        assert_eq!(pool.available_count(), 0);
        pool.release(id);
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn test_pool_release_unknown_id_is_noop() {
        let mut pool = BufferPool::new(2, 64);
        let before = pool.available_count();
        pool.release(999);
        assert_eq!(pool.available_count(), before);
    }

    #[test]
    fn test_pool_total_count_unchanged_after_ops() {
        let mut pool = BufferPool::new(5, 128);
        let ids: Vec<u64> = (0..5).filter_map(|_| pool.acquire()).collect();
        assert_eq!(pool.total_count(), 5);
        for id in ids {
            pool.release(id);
        }
        assert_eq!(pool.total_count(), 5);
    }

    // ---- PressureThresholds ----

    #[test]
    fn test_pressure_thresholds_default() {
        let t = PressureThresholds::default();
        assert_eq!(t.level_for_fraction(0.0), MemoryPressureLevel::Low);
        assert_eq!(t.level_for_fraction(0.5), MemoryPressureLevel::Medium);
        assert_eq!(t.level_for_fraction(0.75), MemoryPressureLevel::High);
        assert_eq!(t.level_for_fraction(0.9), MemoryPressureLevel::Critical);
        assert_eq!(t.level_for_fraction(1.0), MemoryPressureLevel::Critical);
    }

    #[test]
    fn test_pressure_thresholds_custom() {
        let t = PressureThresholds::new(0.4, 0.6, 0.8);
        assert_eq!(t.level_for_fraction(0.3), MemoryPressureLevel::Low);
        assert_eq!(t.level_for_fraction(0.5), MemoryPressureLevel::Medium);
        assert_eq!(t.level_for_fraction(0.7), MemoryPressureLevel::High);
        assert_eq!(t.level_for_fraction(0.85), MemoryPressureLevel::Critical);
    }

    #[test]
    #[should_panic(expected = "Watermarks must be in ascending order")]
    fn test_pressure_thresholds_out_of_order_panics() {
        let _ = PressureThresholds::new(0.8, 0.5, 0.9);
    }

    #[test]
    #[should_panic(expected = "Watermarks must be in [0.0, 1.0]")]
    fn test_pressure_thresholds_out_of_range_panics() {
        let _ = PressureThresholds::new(0.5, 0.75, 1.5);
    }

    #[test]
    #[should_panic(expected = "Watermarks must be in [0.0, 1.0]")]
    fn test_pressure_thresholds_nan_panics() {
        // NaN is not contained in [0.0, 1.0], so it is rejected by the range check.
        let _ = PressureThresholds::new(f64::NAN, 0.75, 0.9);
    }

    // ---- Pool pressure tracking ----

    #[test]
    fn test_pool_initial_pressure_level_low() {
        let pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Low);
    }

    #[test]
    fn test_pool_pressure_level_increases_with_usage() {
        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        let _id0 = pool.acquire(); // 1/4 = 0.25 -> Low
        let _id1 = pool.acquire(); // 2/4 = 0.50 -> Medium
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Medium);
        let _id2 = pool.acquire(); // 3/4 = 0.75 -> High
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::High);
        let _id3 = pool.acquire(); // 4/4 = 1.00 -> Critical
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Critical);
    }

    #[test]
    fn test_pool_pressure_level_decreases_on_release() {
        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        let id0 = pool.acquire().expect("should acquire");
        let id1 = pool.acquire().expect("should acquire");
        let id2 = pool.acquire().expect("should acquire");
        let id3 = pool.acquire().expect("should acquire");
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Critical);
        pool.release(id3);
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::High);
        pool.release(id2);
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Medium);
        pool.release(id1);
        pool.release(id0);
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Low);
    }

    // ---- Pressure callbacks ----

    #[test]
    fn test_pressure_callback_fired_on_transition() {
        let events: Arc<Mutex<Vec<MemoryPressureLevel>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);

        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        pool.add_pressure_callback(Box::new(move |level| {
            events_clone.lock().expect("lock").push(level);
        }));

        let _id0 = pool.acquire(); // 0.25 -> still Low, no event
        let _id1 = pool.acquire();
        let _id2 = pool.acquire();
        let _id3 = pool.acquire();

        let recorded = events.lock().expect("lock").clone();
        assert_eq!(
            recorded,
            vec![
                MemoryPressureLevel::Medium,
                MemoryPressureLevel::High,
                MemoryPressureLevel::Critical,
            ]
        );
    }

    #[test]
    fn test_pressure_callback_not_fired_on_same_level() {
        let events: Arc<Mutex<Vec<MemoryPressureLevel>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);

        // 2 of 10 in use = 0.2, which stays below the 0.5 medium watermark.
        let mut pool = BufferPool::with_pressure(10, 64, PressureThresholds::default());
        pool.add_pressure_callback(Box::new(move |level| {
            events_clone.lock().expect("lock").push(level);
        }));

        let _a = pool.acquire();
        let _b = pool.acquire();

        let recorded = events.lock().expect("lock").clone();
        assert!(recorded.is_empty());
    }

    #[test]
    fn test_pressure_callback_fired_on_release() {
        let events: Arc<Mutex<Vec<MemoryPressureLevel>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);

        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        let ids: Vec<u64> = (0..4).filter_map(|_| pool.acquire()).collect();
        assert_eq!(ids.len(), 4);
        // Register after saturating the pool so only downward transitions are recorded.
        pool.add_pressure_callback(Box::new(move |level| {
            events_clone.lock().expect("lock").push(level);
        }));

        pool.release(ids[3]); // 3/4 -> High
        pool.release(ids[2]); // 2/4 -> Medium

        let recorded = events.lock().expect("lock").clone();
        assert_eq!(
            recorded,
            vec![MemoryPressureLevel::High, MemoryPressureLevel::Medium]
        );
    }

    // ---- Shrinking ----

    #[test]
    fn test_shrink_to_removes_free_buffers() {
        let mut pool = BufferPool::new(8, 64);
        let removed = pool.shrink_to(4);
        assert_eq!(removed, 4);
        assert_eq!(pool.total_count(), 4);
        assert_eq!(pool.available_count(), 4);
    }

    #[test]
    fn test_shrink_to_does_not_remove_in_use_buffers() {
        let mut pool = BufferPool::new(4, 64);
        let id0 = pool.acquire().expect("should acquire");
        let id1 = pool.acquire().expect("should acquire");
        // Only the two free buffers can go, so the pool stops at 2, not 1.
        let removed = pool.shrink_to(1);
        assert_eq!(removed, 2);
        assert_eq!(pool.total_count(), 2);
        assert_eq!(pool.in_use_count(), 2);
        pool.release(id0);
        pool.release(id1);
    }

    #[test]
    fn test_shrink_to_removes_trailing_free_buffers_first() {
        let mut pool = BufferPool::new(5, 64);
        let a = pool.acquire().expect("should acquire");
        let b = pool.acquire().expect("should acquire");
        let c = pool.acquire().expect("should acquire");
        pool.release(a);
        pool.release(c);
        // Only `b` is in use now; the three trailing free buffers should go.
        let removed = pool.shrink_to(2);
        assert_eq!(removed, 3);
        let remaining: Vec<u64> = pool.buffers.iter().map(|buf| buf.id).collect();
        assert_eq!(remaining, vec![a, b]);
        assert_eq!(pool.in_use_count(), 1);
        pool.release(b);
    }

    #[test]
    fn test_shrink_to_zero_with_all_free() {
        let mut pool = BufferPool::with_pressure(3, 64, PressureThresholds::default());
        assert_eq!(pool.shrink_to(0), 3);
        assert_eq!(pool.total_count(), 0);
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Low);
        assert!(pool.acquire().is_none());
    }

    #[test]
    fn test_shrink_to_noop_when_already_at_or_below_target() {
        let mut pool = BufferPool::new(4, 64);
        let removed = pool.shrink_to(4);
        assert_eq!(removed, 0);
        assert_eq!(pool.total_count(), 4);

        let removed2 = pool.shrink_to(10);
        assert_eq!(removed2, 0);
        assert_eq!(pool.total_count(), 4);
    }

    #[test]
    fn test_auto_shrink_when_low_pressure() {
        let mut pool = BufferPool::with_pressure(8, 64, PressureThresholds::default());
        let removed = pool.auto_shrink();
        assert!(removed > 0);
        assert!(pool.total_count() < 8);
    }

    #[test]
    fn test_auto_shrink_does_not_shrink_under_pressure() {
        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        let _id0 = pool.acquire();
        let _id1 = pool.acquire(); // 2/4 = 0.5 -> Medium
        let removed = pool.auto_shrink();
        assert_eq!(removed, 0);
    }

    #[test]
    fn test_in_use_count() {
        let mut pool = BufferPool::new(4, 64);
        assert_eq!(pool.in_use_count(), 0);
        let _ = pool.acquire();
        let _ = pool.acquire();
        assert_eq!(pool.in_use_count(), 2);
    }

    #[test]
    fn test_no_thresholds_always_low() {
        let mut pool = BufferPool::new(2, 64);
        let _ = pool.acquire();
        let _ = pool.acquire();
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Low);
    }
}