#![allow(dead_code)]
/// Coarse classification of how much of a [`BufferPool`]'s capacity is in use.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// orders them by severity: `Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum MemoryPressureLevel {
    /// In-use fraction is below every watermark. This is the default level.
    #[default]
    Low,
    /// In-use fraction has reached the medium watermark.
    Medium,
    /// In-use fraction has reached the high watermark.
    High,
    /// In-use fraction has reached the critical watermark.
    Critical,
}

/// Fractional watermarks that map a pool's in-use fraction to a
/// [`MemoryPressureLevel`].
///
/// A fraction at or above a watermark reaches that watermark's level; the most
/// severe level reached wins (see [`PressureThresholds::level_for_fraction`]).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PressureThresholds {
    /// In-use fraction at which pressure becomes [`MemoryPressureLevel::Medium`].
    pub medium_watermark: f64,
    /// In-use fraction at which pressure becomes [`MemoryPressureLevel::High`].
    pub high_watermark: f64,
    /// In-use fraction at which pressure becomes [`MemoryPressureLevel::Critical`].
    pub critical_watermark: f64,
}

impl Default for PressureThresholds {
    /// Watermarks of 50% (medium), 75% (high) and 90% (critical).
    fn default() -> Self {
        Self {
            medium_watermark: 0.5,
            high_watermark: 0.75,
            critical_watermark: 0.9,
        }
    }
}

impl PressureThresholds {
    /// Builds a validated set of watermarks.
    ///
    /// # Panics
    ///
    /// Panics if any watermark lies outside `[0.0, 1.0]` (NaN included), or if
    /// the watermarks are not ordered `medium <= high <= critical`.
    #[must_use]
    pub fn new(medium: f64, high: f64, critical: f64) -> Self {
        let in_unit_interval = |w: &f64| (0.0..=1.0).contains(w);
        assert!(
            [medium, high, critical].iter().all(in_unit_interval),
            "Watermarks must be in [0.0, 1.0]"
        );
        assert!(
            medium <= high && high <= critical,
            "Watermarks must be in ascending order"
        );
        Self {
            medium_watermark: medium,
            high_watermark: high,
            critical_watermark: critical,
        }
    }

    /// Classifies `in_use_fraction` against the watermarks.
    ///
    /// Watermarks are checked from most to least severe, and each comparison is
    /// inclusive (`>=`). A NaN fraction compares false everywhere and therefore
    /// maps to [`MemoryPressureLevel::Low`].
    #[must_use]
    pub fn level_for_fraction(&self, in_use_fraction: f64) -> MemoryPressureLevel {
        match in_use_fraction {
            f if f >= self.critical_watermark => MemoryPressureLevel::Critical,
            f if f >= self.high_watermark => MemoryPressureLevel::High,
            f if f >= self.medium_watermark => MemoryPressureLevel::Medium,
            _ => MemoryPressureLevel::Low,
        }
    }
}

/// Callback invoked with the new level whenever a pool's pressure level changes.
pub type MemoryPressureCallback = Box<dyn Fn(MemoryPressureLevel) + Send + Sync>;
/// Describes the size, alignment and owning pool of a buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferDesc {
    /// Buffer length in bytes.
    pub size_bytes: usize,
    /// Requested alignment in bytes.
    pub alignment: usize,
    /// Identifier of the pool the buffer belongs to.
    pub pool_id: u32,
}

impl BufferDesc {
    /// Page size assumed by [`BufferDesc::is_page_aligned`].
    const PAGE_SIZE: usize = 4096;

    /// Creates a descriptor from its raw parts. No validation is performed.
    #[must_use]
    pub fn new(size_bytes: usize, alignment: usize, pool_id: u32) -> Self {
        Self {
            size_bytes,
            alignment,
            pool_id,
        }
    }

    /// Returns `true` when the requested alignment guarantees 4096-byte page
    /// alignment, i.e. it is a non-zero multiple of 4096.
    ///
    /// Larger page multiples such as 8192 also satisfy page alignment, so the
    /// check is "multiple of", not "equal to".
    #[must_use]
    pub fn is_page_aligned(&self) -> bool {
        self.alignment != 0 && self.alignment % Self::PAGE_SIZE == 0
    }

    /// Number of `slot_size`-byte slots needed to hold `size_bytes`, rounding up.
    /// A zero-byte descriptor needs zero slots.
    ///
    /// # Panics
    ///
    /// Panics if `slot_size` is zero.
    #[must_use]
    pub fn slots_needed(&self, slot_size: usize) -> usize {
        assert!(slot_size > 0, "slot_size must be non-zero");
        self.size_bytes.div_ceil(slot_size)
    }
}

/// A zero-initialised byte buffer tracked by a [`BufferPool`].
#[derive(Debug)]
pub struct PooledBuffer {
    /// Identifier returned by `BufferPool::acquire` and accepted by `BufferPool::release`.
    pub id: u64,
    /// Backing storage, `desc.size_bytes` long when created.
    pub data: Vec<u8>,
    /// Descriptor the buffer was created from.
    pub desc: BufferDesc,
    /// Whether the buffer is currently checked out.
    pub in_use: bool,
}

impl PooledBuffer {
    /// Allocates a zero-filled, initially free buffer of `desc.size_bytes` bytes.
    ///
    /// The storage is a plain `Vec<u8>`. `desc.alignment` is recorded, but it
    /// is not applied to the allocation.
    #[must_use]
    pub fn new(id: u64, desc: BufferDesc) -> Self {
        Self {
            id,
            // Evaluated before `desc` is moved into its field below.
            data: vec![0u8; desc.size_bytes],
            desc,
            in_use: false,
        }
    }

    /// Zeroes the contents and marks the buffer free.
    pub fn reset(&mut self) {
        self.data.fill(0);
        self.in_use = false;
    }

    /// Length of the backing storage in bytes.
    #[must_use]
    pub fn available_size(&self) -> usize {
        self.data.len()
    }
}
/// A set of [`PooledBuffer`]s handed out by id, with optional memory-pressure
/// tracking and change notifications.
pub struct BufferPool {
    /// Every buffer owned by the pool, free or in use.
    pub buffers: Vec<PooledBuffer>,
    /// Set to the initial buffer count by `BufferPool::new`.
    pub next_id: u64,
    /// Watermarks used to compute pressure; `None` disables pressure tracking.
    thresholds: Option<PressureThresholds>,
    /// Level recorded at the last pressure re-evaluation (starts at `Low`).
    last_pressure: MemoryPressureLevel,
    /// Invoked with the new level whenever the recorded level changes.
    pressure_callbacks: Vec<MemoryPressureCallback>,
}

impl std::fmt::Debug for BufferPool {
    // The boxed callbacks have no `Debug` impl, so only summary counts and the
    // last recorded level are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total = self.buffers.len();
        let available = self.available_count();
        let mut out = f.debug_struct("BufferPool");
        out.field("total", &total);
        out.field("available", &available);
        out.field("last_pressure", &self.last_pressure);
        out.finish()
    }
}
impl BufferPool {
    /// Alignment recorded in the descriptor of every buffer created by [`BufferPool::new`].
    const DEFAULT_ALIGNMENT: usize = 64;

    /// Creates a pool of `count` free, zero-filled buffers of `buf_size` bytes each.
    ///
    /// Buffers get ids `0..count`, pool id `0` and a recorded alignment of 64 bytes.
    /// Pressure tracking is disabled, so [`BufferPool::current_pressure_level`]
    /// always reports [`MemoryPressureLevel::Low`].
    #[must_use]
    pub fn new(count: usize, buf_size: usize) -> Self {
        // `usize` -> `u64` is lossless on all targets Rust supports (pointers are <= 64 bits).
        let count = count as u64;
        let buffers = (0..count)
            .map(|id| {
                let desc = BufferDesc::new(buf_size, Self::DEFAULT_ALIGNMENT, 0);
                PooledBuffer::new(id, desc)
            })
            .collect();
        Self {
            buffers,
            next_id: count,
            thresholds: None,
            last_pressure: MemoryPressureLevel::Low,
            pressure_callbacks: Vec::new(),
        }
    }

    /// Like [`BufferPool::new`], but pressure levels are computed from `thresholds`
    /// and reported to registered callbacks on every level change.
    #[must_use]
    pub fn with_pressure(count: usize, buf_size: usize, thresholds: PressureThresholds) -> Self {
        Self {
            thresholds: Some(thresholds),
            ..Self::new(count, buf_size)
        }
    }

    /// Registers `cb`, which will be called with the new level on every pressure
    /// transition. Callbacks run synchronously, in registration order, inside the
    /// pool method that caused the transition.
    pub fn add_pressure_callback(&mut self, cb: MemoryPressureCallback) {
        self.pressure_callbacks.push(cb);
    }

    /// Fraction of buffers currently in use, in `[0.0, 1.0]`; `0.0` for an empty pool.
    #[must_use]
    fn in_use_fraction(&self) -> f64 {
        match self.buffers.len() {
            0 => 0.0,
            total => self.in_use_count() as f64 / total as f64,
        }
    }

    /// Current pressure level. It is always [`MemoryPressureLevel::Low`] when the
    /// pool has no thresholds.
    #[must_use]
    pub fn current_pressure_level(&self) -> MemoryPressureLevel {
        self.thresholds
            .as_ref()
            .map_or(MemoryPressureLevel::Low, |t| {
                t.level_for_fraction(self.in_use_fraction())
            })
    }

    /// Re-evaluates the pressure level. If it differs from the last recorded
    /// level, records it and invokes every callback with the new level.
    fn notify_pressure(&mut self) {
        let current = self.current_pressure_level();
        if current == self.last_pressure {
            return;
        }
        self.last_pressure = current;
        for cb in &self.pressure_callbacks {
            cb(current);
        }
    }

    /// Checks out the lowest-indexed free buffer and returns its id.
    ///
    /// Returns `None` when every buffer is in use. On success the pressure level
    /// is re-evaluated, and callbacks may fire.
    #[must_use]
    pub fn acquire(&mut self) -> Option<u64> {
        let id = {
            let buf = self.buffers.iter_mut().find(|b| !b.in_use)?;
            buf.in_use = true;
            buf.id
        };
        self.notify_pressure();
        Some(id)
    }

    /// Returns the buffer with `id` to the pool, zeroing its contents. This also
    /// happens if the buffer was already free.
    ///
    /// Unknown ids leave every buffer untouched. The pressure level is
    /// re-evaluated in either case.
    pub fn release(&mut self, id: u64) {
        if let Some(buf) = self.buffers.iter_mut().find(|b| b.id == id) {
            buf.reset();
        }
        self.notify_pressure();
    }

    /// Removes free buffers, starting from the end of the pool. It stops when the
    /// pool holds `target_count` buffers or no free buffer is left to remove.
    ///
    /// In-use buffers are never removed, and the relative order of the survivors
    /// is preserved. Returns the number of buffers removed. When any were
    /// removed, the pressure level is re-evaluated.
    pub fn shrink_to(&mut self, target_count: usize) -> usize {
        let excess = self.buffers.len().saturating_sub(target_count);
        if excess == 0 {
            return 0;
        }
        // Walk backwards until `excess` free buffers have been seen (or the front
        // is reached). The free buffers at indices `cut..` are exactly the ones
        // the back-to-front policy removes. A single `retain` pass then drops
        // them without shifting the tail once per removal.
        let mut still_to_remove = excess;
        let mut cut = self.buffers.len();
        while cut > 0 && still_to_remove > 0 {
            cut -= 1;
            if !self.buffers[cut].in_use {
                still_to_remove -= 1;
            }
        }
        let removed = excess - still_to_remove;

        // `retain` visits elements exactly once, in order, so `index` tracks
        // each element's original position.
        let mut index = 0usize;
        self.buffers.retain(|buf| {
            let keep = index < cut || buf.in_use;
            index += 1;
            keep
        });

        if removed > 0 {
            self.notify_pressure();
        }
        removed
    }

    /// Shrinks an idle pool toward half its size.
    ///
    /// Does nothing unless the pressure level is `Low` and more than half of the
    /// buffers are free. The target starts at `total / 2` and is raised as
    /// needed so that:
    /// - at least one free buffer remains, and
    /// - when thresholds are configured, the resulting in-use fraction still maps
    ///   to `Low`.
    ///
    /// So a shrink performed under low pressure never raises the pressure level
    /// by itself. Returns the number of buffers removed.
    pub fn auto_shrink(&mut self) -> usize {
        if self.current_pressure_level() != MemoryPressureLevel::Low {
            return 0;
        }
        let total = self.buffers.len();
        let available = self.available_count();
        if total == 0 || available <= total / 2 {
            return 0;
        }
        let in_use = total - available;
        // With an odd `total`, plain halving could leave only in-use buffers
        // (e.g. 7 buffers, 3 in use -> 3), exhausting the pool.
        let mut target = (total / 2).max(in_use + 1);
        if let Some(t) = &self.thresholds {
            // Bounded by `total`, where the fraction equals the current one and
            // was just checked to be `Low`.
            while target < total
                && t.level_for_fraction(in_use as f64 / target as f64) != MemoryPressureLevel::Low
            {
                target += 1;
            }
        }
        self.shrink_to(target)
    }

    /// Number of buffers not currently checked out.
    #[must_use]
    pub fn available_count(&self) -> usize {
        self.buffers.iter().filter(|b| !b.in_use).count()
    }

    /// Number of buffers owned by the pool, free or in use.
    #[must_use]
    pub fn total_count(&self) -> usize {
        self.buffers.len()
    }

    /// Number of buffers currently checked out.
    #[must_use]
    pub fn in_use_count(&self) -> usize {
        self.buffers.iter().filter(|b| b.in_use).count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared log of every level a recording callback has received.
    type EventLog = Arc<Mutex<Vec<MemoryPressureLevel>>>;

    /// Returns an empty event log plus a callback that appends each level it receives to it.
    fn recording_callback() -> (EventLog, MemoryPressureCallback) {
        let events: EventLog = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&events);
        let cb: MemoryPressureCallback = Box::new(move |level| {
            sink.lock().expect("lock").push(level);
        });
        (events, cb)
    }

    /// Snapshot of the levels recorded so far.
    fn recorded(events: &Mutex<Vec<MemoryPressureLevel>>) -> Vec<MemoryPressureLevel> {
        events.lock().expect("lock").clone()
    }

    #[test]
    fn test_buffer_desc_new() {
        let desc = BufferDesc::new(1024, 64, 1);
        assert_eq!(desc.size_bytes, 1024);
        assert_eq!(desc.alignment, 64);
        assert_eq!(desc.pool_id, 1);
    }

    #[test]
    fn test_buffer_desc_is_page_aligned_true() {
        let desc = BufferDesc::new(8192, 4096, 0);
        assert!(desc.is_page_aligned());
    }

    #[test]
    fn test_buffer_desc_is_page_aligned_false() {
        let desc = BufferDesc::new(8192, 64, 0);
        assert!(!desc.is_page_aligned());
    }

    #[test]
    fn test_buffer_desc_slots_needed_exact() {
        let desc = BufferDesc::new(1024, 64, 0);
        assert_eq!(desc.slots_needed(512), 2);
    }

    #[test]
    fn test_buffer_desc_slots_needed_round_up() {
        let desc = BufferDesc::new(1025, 64, 0);
        assert_eq!(desc.slots_needed(512), 3);
    }

    #[test]
    fn test_buffer_desc_slots_needed_single_slot() {
        let desc = BufferDesc::new(100, 64, 0);
        assert_eq!(desc.slots_needed(200), 1);
    }

    #[test]
    #[should_panic(expected = "slot_size must be non-zero")]
    fn test_buffer_desc_slots_needed_zero_panics() {
        let _ = BufferDesc::new(100, 64, 0).slots_needed(0);
    }

    #[test]
    fn test_pooled_buffer_initial_state() {
        let desc = BufferDesc::new(256, 64, 0);
        let buf = PooledBuffer::new(42, desc);
        assert_eq!(buf.id, 42);
        assert!(!buf.in_use);
        assert_eq!(buf.available_size(), 256);
        assert!(buf.data.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_pooled_buffer_reset() {
        let desc = BufferDesc::new(4, 64, 0);
        let mut buf = PooledBuffer::new(1, desc);
        buf.in_use = true;
        buf.data[0] = 0xFF;
        buf.reset();
        assert!(!buf.in_use);
        assert!(buf.data.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_pooled_buffer_available_size() {
        let desc = BufferDesc::new(512, 64, 0);
        let buf = PooledBuffer::new(0, desc);
        assert_eq!(buf.available_size(), 512);
    }

    #[test]
    fn test_pool_new() {
        let pool = BufferPool::new(4, 1024);
        assert_eq!(pool.total_count(), 4);
        assert_eq!(pool.available_count(), 4);
    }

    #[test]
    fn test_pool_debug_output() {
        let pool = BufferPool::new(4, 64);
        assert_eq!(
            format!("{:?}", pool),
            "BufferPool { total: 4, available: 4, last_pressure: Low }"
        );
    }

    #[test]
    fn test_pool_acquire_returns_id() {
        let mut pool = BufferPool::new(2, 256);
        let id = pool.acquire();
        assert!(id.is_some());
    }

    #[test]
    fn test_pool_acquire_reuses_released_buffer() {
        let mut pool = BufferPool::new(2, 64);
        assert_eq!(pool.acquire(), Some(0));
        assert_eq!(pool.acquire(), Some(1));
        pool.release(0);
        assert_eq!(pool.acquire(), Some(0));
    }

    #[test]
    fn test_pool_acquire_exhausts_buffers() {
        let mut pool = BufferPool::new(2, 256);
        let _id1 = pool.acquire().expect("acquire should succeed");
        let _id2 = pool.acquire().expect("acquire should succeed");
        assert!(pool.acquire().is_none());
    }

    #[test]
    fn test_pool_available_count_decrements_on_acquire() {
        let mut pool = BufferPool::new(3, 64);
        assert_eq!(pool.available_count(), 3);
        let _ = pool.acquire();
        assert_eq!(pool.available_count(), 2);
        let _ = pool.acquire();
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn test_pool_release_makes_buffer_available() {
        let mut pool = BufferPool::new(1, 64);
        let id = pool.acquire().expect("acquire should succeed");
        assert_eq!(pool.available_count(), 0);
        pool.release(id);
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn test_pool_release_unknown_id_is_noop() {
        let mut pool = BufferPool::new(2, 64);
        let before = pool.available_count();
        pool.release(999);
        assert_eq!(pool.available_count(), before);
    }

    #[test]
    fn test_pool_total_count_unchanged_after_ops() {
        let mut pool = BufferPool::new(5, 128);
        let ids: Vec<u64> = (0..5).filter_map(|_| pool.acquire()).collect();
        assert_eq!(pool.total_count(), 5);
        for id in ids {
            pool.release(id);
        }
        assert_eq!(pool.total_count(), 5);
    }

    #[test]
    fn test_pressure_thresholds_default() {
        let t = PressureThresholds::default();
        assert_eq!(t.level_for_fraction(0.0), MemoryPressureLevel::Low);
        assert_eq!(t.level_for_fraction(0.5), MemoryPressureLevel::Medium);
        assert_eq!(t.level_for_fraction(0.75), MemoryPressureLevel::High);
        assert_eq!(t.level_for_fraction(0.9), MemoryPressureLevel::Critical);
        assert_eq!(t.level_for_fraction(1.0), MemoryPressureLevel::Critical);
    }

    #[test]
    fn test_pressure_thresholds_custom() {
        let t = PressureThresholds::new(0.4, 0.6, 0.8);
        assert_eq!(t.level_for_fraction(0.3), MemoryPressureLevel::Low);
        assert_eq!(t.level_for_fraction(0.5), MemoryPressureLevel::Medium);
        assert_eq!(t.level_for_fraction(0.7), MemoryPressureLevel::High);
        assert_eq!(t.level_for_fraction(0.85), MemoryPressureLevel::Critical);
    }

    #[test]
    #[should_panic(expected = "Watermarks must be in ascending order")]
    fn test_pressure_thresholds_out_of_order_panics() {
        let _ = PressureThresholds::new(0.8, 0.5, 0.9);
    }

    #[test]
    #[should_panic(expected = "Watermarks must be in [0.0, 1.0]")]
    fn test_pressure_thresholds_out_of_range_panics() {
        let _ = PressureThresholds::new(0.5, 0.75, 1.5);
    }

    #[test]
    fn test_pool_initial_pressure_level_low() {
        let pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Low);
    }

    #[test]
    fn test_pool_pressure_level_increases_with_usage() {
        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        let _id0 = pool.acquire();
        let _id1 = pool.acquire();
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Medium);
        let _id2 = pool.acquire();
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::High);
        let _id3 = pool.acquire();
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Critical);
    }

    #[test]
    fn test_pool_pressure_level_decreases_on_release() {
        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        let id0 = pool.acquire().expect("should acquire");
        let id1 = pool.acquire().expect("should acquire");
        let id2 = pool.acquire().expect("should acquire");
        let id3 = pool.acquire().expect("should acquire");
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Critical);
        pool.release(id3);
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::High);
        pool.release(id2);
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Medium);
        pool.release(id1);
        pool.release(id0);
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Low);
    }

    #[test]
    fn test_pressure_callback_fired_on_transition() {
        let (events, cb) = recording_callback();
        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        pool.add_pressure_callback(cb);
        let _id0 = pool.acquire();
        let _id1 = pool.acquire();
        let _id2 = pool.acquire();
        let _id3 = pool.acquire();
        assert_eq!(
            recorded(&events),
            vec![
                MemoryPressureLevel::Medium,
                MemoryPressureLevel::High,
                MemoryPressureLevel::Critical,
            ]
        );
    }

    #[test]
    fn test_pressure_callback_fired_on_release() {
        let (events, cb) = recording_callback();
        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        pool.add_pressure_callback(cb);
        let ids: Vec<u64> = (0..4).filter_map(|_| pool.acquire()).collect();
        for id in ids.into_iter().rev() {
            pool.release(id);
        }
        assert_eq!(
            recorded(&events),
            vec![
                MemoryPressureLevel::Medium,
                MemoryPressureLevel::High,
                MemoryPressureLevel::Critical,
                MemoryPressureLevel::High,
                MemoryPressureLevel::Medium,
                MemoryPressureLevel::Low,
            ]
        );
    }

    #[test]
    fn test_pressure_callback_not_fired_on_same_level() {
        let (events, cb) = recording_callback();
        let mut pool = BufferPool::with_pressure(10, 64, PressureThresholds::default());
        pool.add_pressure_callback(cb);
        let _a = pool.acquire();
        let _b = pool.acquire();
        assert!(recorded(&events).is_empty());
    }

    #[test]
    fn test_release_unknown_id_fires_no_callback() {
        let (events, cb) = recording_callback();
        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        pool.add_pressure_callback(cb);
        pool.release(999);
        assert!(recorded(&events).is_empty());
    }

    #[test]
    fn test_shrink_to_removes_free_buffers() {
        let mut pool = BufferPool::new(8, 64);
        let removed = pool.shrink_to(4);
        assert_eq!(removed, 4);
        assert_eq!(pool.total_count(), 4);
        assert_eq!(pool.available_count(), 4);
    }

    #[test]
    fn test_shrink_to_does_not_remove_in_use_buffers() {
        let mut pool = BufferPool::new(4, 64);
        let id0 = pool.acquire().expect("should acquire");
        let id1 = pool.acquire().expect("should acquire");
        let removed = pool.shrink_to(1);
        assert_eq!(removed, 2);
        assert_eq!(pool.total_count(), 2);
        assert_eq!(pool.in_use_count(), 2);
        pool.release(id0);
        pool.release(id1);
    }

    #[test]
    fn test_shrink_to_noop_when_already_at_or_below_target() {
        let mut pool = BufferPool::new(4, 64);
        let removed = pool.shrink_to(4);
        assert_eq!(removed, 0);
        assert_eq!(pool.total_count(), 4);
        let removed2 = pool.shrink_to(10);
        assert_eq!(removed2, 0);
        assert_eq!(pool.total_count(), 4);
    }

    #[test]
    fn test_shrink_to_reports_pressure_change() {
        let (events, cb) = recording_callback();
        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        pool.add_pressure_callback(cb);
        let _a = pool.acquire().expect("should acquire");
        let _b = pool.acquire().expect("should acquire");
        assert_eq!(pool.shrink_to(2), 2);
        assert_eq!(
            recorded(&events),
            vec![MemoryPressureLevel::Medium, MemoryPressureLevel::Critical]
        );
    }

    #[test]
    fn test_auto_shrink_when_low_pressure() {
        let mut pool = BufferPool::with_pressure(8, 64, PressureThresholds::default());
        let removed = pool.auto_shrink();
        assert!(removed > 0);
        assert!(pool.total_count() < 8);
    }

    #[test]
    fn test_auto_shrink_does_not_shrink_under_pressure() {
        let mut pool = BufferPool::with_pressure(4, 64, PressureThresholds::default());
        let _id0 = pool.acquire();
        let _id1 = pool.acquire();
        let removed = pool.auto_shrink();
        assert_eq!(removed, 0);
    }

    #[test]
    fn test_auto_shrink_on_empty_pool_is_noop() {
        let mut pool = BufferPool::new(0, 64);
        assert_eq!(pool.auto_shrink(), 0);
        assert_eq!(pool.total_count(), 0);
    }

    #[test]
    fn test_in_use_count() {
        let mut pool = BufferPool::new(4, 64);
        assert_eq!(pool.in_use_count(), 0);
        let _ = pool.acquire();
        let _ = pool.acquire();
        assert_eq!(pool.in_use_count(), 2);
    }

    #[test]
    fn test_no_thresholds_always_low() {
        let mut pool = BufferPool::new(2, 64);
        let _ = pool.acquire();
        let _ = pool.acquire();
        assert_eq!(pool.current_pressure_level(), MemoryPressureLevel::Low);
    }
}