use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use super::block_storage::BlockStorage;
use super::buffer_manager::BufferManager;
use super::disk_manager::MmapDiskManager;
use super::memory_monitor::{MemoryPressureLevel, MemoryPressureMonitor};
/// Tuning parameters for the adaptive buffer-pool controller.
///
/// Pool sizes are expressed in buffer frames (the controller converts its
/// memory budget into frames before comparing).
#[derive(Debug, Clone)]
pub struct AdaptivePoolConfig {
    /// Floor used when shrinking the pool, and the target under critical memory pressure.
    pub min_pool_size: usize,
    /// Upper bound the pool may grow to.
    pub max_pool_size: usize,
    /// Fraction of currently available system memory the pool may occupy when growing.
    pub target_memory_fraction: f64,
    /// Hit-rate set point fed to the PID controller.
    pub target_hit_rate: f64,
    /// Growth is only considered while the observed hit rate is below this value.
    pub min_hit_rate_for_growth: f64,
    /// Time between two control decisions.
    pub adjustment_interval: Duration,
    /// Maximum number of frames added in a single decision.
    pub max_growth_step: usize,
    /// Maximum number of frames removed in a single (non-critical) decision.
    pub max_shrink_step: usize,
    /// Proportional gain.
    pub kp: f64,
    /// Integral gain.
    pub ki: f64,
    /// Derivative gain.
    pub kd: f64,
    /// When `false`, starting the controller is a no-op.
    pub enabled: bool,
}

impl Default for AdaptivePoolConfig {
    /// Conservative defaults: 16..=1024 frames, a 95% hit-rate target,
    /// re-evaluated every ten seconds.
    fn default() -> Self {
        let adjustment_interval = Duration::from_secs(10);
        Self {
            enabled: true,
            adjustment_interval,
            // Size bounds and per-step limits.
            min_pool_size: 16,
            max_pool_size: 1024,
            max_growth_step: 16,
            max_shrink_step: 8,
            // Targets.
            target_memory_fraction: 0.25,
            target_hit_rate: 0.95,
            min_hit_rate_for_growth: 0.90,
            // PID gains.
            kp: 0.5,
            ki: 0.1,
            kd: 0.05,
        }
    }
}
/// Lock-free hit/miss counters for the buffer cache.
///
/// All operations use `Ordering::Relaxed`: the counters are statistics only and
/// are not used to synchronize any other memory.
#[derive(Debug, Default)]
pub struct CacheStats {
    hits: AtomicU64,
    misses: AtomicU64,
}

impl CacheStats {
    /// Creates a counter pair starting at zero hits and zero misses.
    pub fn new() -> Self {
        Self::default()
    }

    /// Counts one cache hit.
    #[inline]
    pub fn record_hit(&self) {
        self.hits.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one cache miss.
    #[inline]
    pub fn record_miss(&self) {
        self.misses.fetch_add(1, Ordering::Relaxed);
    }

    /// Hit fraction for the given counts; an empty window reports a perfect `1.0`.
    fn ratio(hits: u64, misses: u64) -> f64 {
        match hits + misses {
            0 => 1.0,
            total => hits as f64 / total as f64,
        }
    }

    /// Atomically takes and zeroes each counter, returning `(hit_rate, hits, misses)`
    /// for the window that just ended.
    ///
    /// The two counters are swapped one after another, so an access recorded
    /// in between is attributed to the next window rather than lost.
    pub fn get_and_reset(&self) -> (f64, u64, u64) {
        let taken_hits = self.hits.swap(0, Ordering::Relaxed);
        let taken_misses = self.misses.swap(0, Ordering::Relaxed);
        (Self::ratio(taken_hits, taken_misses), taken_hits, taken_misses)
    }

    /// Current hit fraction without resetting (`1.0` when nothing was recorded).
    pub fn hit_rate(&self) -> f64 {
        let (h, m) = self.counts();
        Self::ratio(h, m)
    }

    /// Total number of recorded hits and misses.
    pub fn total_accesses(&self) -> u64 {
        let (h, m) = self.counts();
        h + m
    }

    /// Raw `(hits, misses)` counts.
    pub fn counts(&self) -> (u64, u64) {
        let h = self.hits.load(Ordering::Relaxed);
        let m = self.misses.load(Ordering::Relaxed);
        (h, m)
    }
}
/// Discrete PID controller whose integral term is clamped to `[-100, 100]`
/// (anti-windup).
///
/// `compute` is meant to be called once per control tick with the current
/// error and the seconds elapsed since the previous tick.
#[derive(Debug)]
struct PidController {
    kp: f64,
    ki: f64,
    kd: f64,
    /// Accumulated `error * dt`, kept within `[integral_min, integral_max]`.
    integral: f64,
    /// Error passed to the previous `compute` call (0.0 initially and after `reset`).
    prev_error: f64,
    integral_min: f64,
    integral_max: f64,
}

impl PidController {
    /// Builds a controller with the given gains and a zeroed state.
    fn new(kp: f64, ki: f64, kd: f64) -> Self {
        Self {
            kp,
            ki,
            kd,
            integral: 0.0,
            prev_error: 0.0,
            integral_min: -100.0,
            integral_max: 100.0,
        }
    }

    /// Advances the controller by one step and returns `P + I + D`.
    ///
    /// The derivative term is skipped (0.0) unless `dt > 0`. Note that the very
    /// first call differentiates against an implicit previous error of 0.0.
    fn compute(&mut self, error: f64, dt: f64) -> f64 {
        let proportional = self.kp * error;

        self.integral = (self.integral + error * dt).clamp(self.integral_min, self.integral_max);
        let integral_term = self.ki * self.integral;

        // Record this error for the next tick while keeping the old one for the slope.
        let previous = std::mem::replace(&mut self.prev_error, error);
        let derivative = if dt > 0.0 {
            self.kd * (error - previous) / dt
        } else {
            0.0
        };

        proportional + integral_term + derivative
    }

    /// Forgets accumulated integral and the remembered previous error.
    fn reset(&mut self) {
        self.prev_error = 0.0;
        self.integral = 0.0;
    }
}
/// Point-in-time snapshot of the adaptive pool controller's activity,
/// as returned by `AdaptivePoolController::stats`.
#[derive(Debug, Clone, Copy)]
pub struct AdaptivePoolStats {
    /// Pool size (in buffer frames) as last recorded by the controller.
    pub current_size: usize,
    /// Number of resize decisions taken, counted even when the resize call failed.
    pub adjustments: u64,
    /// Number of successful grow operations.
    pub grows: u64,
    /// Number of successful shrink operations.
    pub shrinks: u64,
    /// Most recently recorded cache hit rate (1.0 before any is recorded).
    pub last_hit_rate: f64,
    /// Most recently recorded memory-pressure level (`Normal` before any is recorded).
    pub last_pressure: MemoryPressureLevel,
}
/// Counters and last-observed values shared between an `AdaptivePoolController`
/// handle and its background thread.
///
/// The background thread is the only writer; the handle only reads them, via
/// `pool_size()` and `stats()`.
struct SharedControlState {
    /// Pool size (in frames) as last recorded by the controller thread.
    current_size: AtomicUsize,
    /// Resize decisions taken, including ones whose grow/shrink call failed.
    adjustments: AtomicU64,
    /// Successful `grow_pool` calls.
    grows: AtomicU64,
    /// Successful `shrink_pool` calls.
    shrinks: AtomicU64,
    /// Hit rate of the most recent control window that saw any cache accesses.
    last_hit_rate: RwLock<f64>,
    /// Memory-pressure level sampled on the most recent control tick.
    last_pressure: RwLock<MemoryPressureLevel>,
}

/// Background controller that periodically resizes a `BufferManager` pool.
///
/// Every `adjustment_interval` the controller thread samples (and resets) the
/// cache hit/miss counters and reads the current memory pressure. The hit-rate
/// error (`target_hit_rate - hit_rate`) drives a PID controller whose output,
/// scaled by the current pool size, becomes a proposed resize. That proposal is
/// then limited by the pressure level and by the step/size bounds in
/// [`AdaptivePoolConfig`].
pub struct AdaptivePoolController<S: BlockStorage + 'static = MmapDiskManager> {
    config: AdaptivePoolConfig,
    buffer_manager: Arc<BufferManager<S>>,
    cache_stats: Arc<CacheStats>,
    memory_monitor: Arc<MemoryPressureMonitor>,
    // Shared with the controller thread so `pool_size()`/`stats()` observe its updates.
    state: Arc<SharedControlState>,
    controller_thread: Option<JoinHandle<()>>,
    shutdown: Arc<AtomicBool>,
}

impl<S: BlockStorage + 'static> AdaptivePoolController<S> {
    /// Creates a stopped controller whose size tracking starts at the buffer
    /// manager's current pool size. Call [`start`](Self::start) to begin adjusting.
    pub fn new(
        config: AdaptivePoolConfig,
        buffer_manager: Arc<BufferManager<S>>,
        cache_stats: Arc<CacheStats>,
        memory_monitor: Arc<MemoryPressureMonitor>,
    ) -> Self {
        let initial_size = buffer_manager.pool_size();
        Self {
            config,
            buffer_manager,
            cache_stats,
            memory_monitor,
            state: Arc::new(SharedControlState {
                current_size: AtomicUsize::new(initial_size),
                adjustments: AtomicU64::new(0),
                grows: AtomicU64::new(0),
                shrinks: AtomicU64::new(0),
                last_hit_rate: RwLock::new(1.0),
                last_pressure: RwLock::new(MemoryPressureLevel::Normal),
            }),
            controller_thread: None,
            shutdown: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Spawns the background controller thread.
    ///
    /// Does nothing when the controller is disabled in the config or a thread
    /// is already running. May be called again after [`stop`](Self::stop).
    ///
    /// # Panics
    ///
    /// Panics if the OS refuses to spawn the thread.
    pub fn start(&mut self) {
        if !self.config.enabled || self.controller_thread.is_some() {
            return;
        }
        // Clear a flag left over from a previous `stop()`; otherwise a restarted
        // thread would exit on its first check.
        self.shutdown.store(false, Ordering::Release);

        let config = self.config.clone();
        let buffer_manager = Arc::clone(&self.buffer_manager);
        let cache_stats = Arc::clone(&self.cache_stats);
        let memory_monitor = Arc::clone(&self.memory_monitor);
        let state = Arc::clone(&self.state);
        let shutdown = Arc::clone(&self.shutdown);

        let handle = thread::Builder::new()
            .name("artrie-adaptive-pool".to_string())
            .spawn(move || {
                Self::control_loop(
                    config,
                    buffer_manager,
                    cache_stats,
                    memory_monitor,
                    state,
                    shutdown,
                );
            })
            .expect("failed to spawn adaptive pool controller thread");
        self.controller_thread = Some(handle);
    }

    /// Pool size (in frames) as last recorded by the controller.
    pub fn pool_size(&self) -> usize {
        self.state.current_size.load(Ordering::Relaxed)
    }

    /// Snapshot of the controller's counters and last observations.
    pub fn stats(&self) -> AdaptivePoolStats {
        let state = &self.state;
        AdaptivePoolStats {
            current_size: state.current_size.load(Ordering::Relaxed),
            adjustments: state.adjustments.load(Ordering::Relaxed),
            grows: state.grows.load(Ordering::Relaxed),
            shrinks: state.shrinks.load(Ordering::Relaxed),
            last_hit_rate: *state.last_hit_rate.read(),
            last_pressure: *state.last_pressure.read(),
        }
    }

    /// `true` between a successful [`start`](Self::start) and [`stop`](Self::stop).
    pub fn is_running(&self) -> bool {
        self.controller_thread.is_some() && !self.shutdown.load(Ordering::Acquire)
    }

    /// Signals the controller thread to exit and waits for it to finish.
    ///
    /// Safe to call repeatedly; also invoked from `Drop`.
    pub fn stop(&mut self) {
        self.shutdown.store(true, Ordering::Release);
        if let Some(handle) = self.controller_thread.take() {
            // Wake the thread from its inter-tick wait so shutdown isn't delayed
            // by up to a full `adjustment_interval`.
            handle.thread().unpark();
            // A panic in the controller thread is deliberately not re-raised:
            // this also runs from `Drop`.
            let _ = handle.join();
        }
    }

    /// Waits for `interval`, returning early if `shutdown` is set.
    ///
    /// Returns `true` when shutdown was requested (the caller should exit) and
    /// `false` when the interval elapsed normally. Uses `thread::park_timeout`
    /// so `stop()` can cut the wait short with `unpark`. Spurious wake-ups are
    /// handled by re-checking both the flag and the deadline.
    fn wait_for_next_tick(interval: Duration, shutdown: &AtomicBool) -> bool {
        let deadline = Instant::now().checked_add(interval);
        loop {
            if shutdown.load(Ordering::Acquire) {
                return true;
            }
            match deadline {
                Some(deadline) => {
                    let now = Instant::now();
                    if now >= deadline {
                        return false;
                    }
                    thread::park_timeout(deadline - now);
                }
                // Interval too large to represent as an `Instant`: wait until unparked.
                None => thread::park(),
            }
        }
    }

    /// `size` reduced by `step` and clamped to `floor`, but never larger than
    /// `size`. A pool that already sits below `floor` is therefore not grown
    /// by a shrink decision.
    fn shrunk_size(size: usize, step: usize, floor: usize) -> usize {
        size.saturating_sub(step).max(floor.min(size))
    }

    /// Body of the controller thread: one resize decision per
    /// `adjustment_interval` until `shutdown` is set.
    fn control_loop(
        config: AdaptivePoolConfig,
        buffer_manager: Arc<BufferManager<S>>,
        cache_stats: Arc<CacheStats>,
        memory_monitor: Arc<MemoryPressureMonitor>,
        state: Arc<SharedControlState>,
        shutdown: Arc<AtomicBool>,
    ) {
        // NOTE(review): bytes per pool frame, used to turn the memory budget into a
        // frame count; assumed to match the buffer manager's frame size — confirm.
        const FRAME_SIZE: usize = 256 * 1024;

        let mut pid = PidController::new(config.kp, config.ki, config.kd);
        let mut last_time = Instant::now();
        let mut size = state.current_size.load(Ordering::Relaxed);

        while !Self::wait_for_next_tick(config.adjustment_interval, &shutdown) {
            // Sample this window's statistics (resetting the counters for the next one).
            let (hit_rate, hits, misses) = cache_stats.get_and_reset();
            let memory_stats = memory_monitor.current_stats();
            let pressure = memory_monitor.current_level();
            *state.last_pressure.write() = pressure;

            let now = Instant::now();
            let dt = now.duration_since(last_time).as_secs_f64();
            last_time = now;

            // An idle window reports a synthetic 1.0 hit rate; don't feed it to the PID.
            if hits + misses == 0 {
                continue;
            }
            *state.last_hit_rate.write() = hit_rate;

            let available_memory = memory_stats.mem_available as f64;
            let memory_target_size =
                ((available_memory * config.target_memory_fraction) / FRAME_SIZE as f64) as usize;
            // Positive error (hit rate below target) asks for a larger pool.
            let pid_output = pid.compute(config.target_hit_rate - hit_rate, dt);
            let size_delta = (pid_output * size as f64) as isize;

            let new_size = match pressure {
                MemoryPressureLevel::Critical => {
                    pid.reset();
                    // Drop straight to the floor, but never grow a pool already below it.
                    config.min_pool_size.min(size)
                }
                MemoryPressureLevel::Low => {
                    // Under low memory only shrinking is allowed.
                    let shrink = size_delta.min(0).unsigned_abs().min(config.max_shrink_step);
                    Self::shrunk_size(size, shrink, config.min_pool_size)
                }
                MemoryPressureLevel::Normal => {
                    if size_delta > 0 && hit_rate < config.min_hit_rate_for_growth {
                        let grow = size_delta.unsigned_abs().min(config.max_growth_step);
                        let ceiling = config.max_pool_size.min(memory_target_size);
                        // A ceiling at or below the current size only blocks growth; it must
                        // not turn a grow decision into an unbounded shrink (e.g. to 0 frames
                        // when little memory is available).
                        if ceiling > size {
                            size.saturating_add(grow).min(ceiling)
                        } else {
                            size
                        }
                    } else if size_delta < 0 {
                        let shrink = size_delta.unsigned_abs().min(config.max_shrink_step);
                        Self::shrunk_size(size, shrink, config.min_pool_size)
                    } else {
                        size
                    }
                }
            };

            if new_size != size {
                state.adjustments.fetch_add(1, Ordering::Relaxed);
                // Resizing is best-effort: on failure keep the old size and retry next tick.
                if new_size > size {
                    if buffer_manager.grow_pool(new_size - size).is_ok() {
                        state.grows.fetch_add(1, Ordering::Relaxed);
                        size = new_size;
                    }
                } else if buffer_manager.shrink_pool(size - new_size).is_ok() {
                    state.shrinks.fetch_add(1, Ordering::Relaxed);
                    size = new_size;
                }
                state.current_size.store(size, Ordering::Relaxed);
            }
        }
    }
}

impl<S: BlockStorage + 'static> Drop for AdaptivePoolController<S> {
    /// Stops and joins the controller thread, if one is running.
    fn drop(&mut self) {
        self.stop();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_stats_new() {
        let stats = CacheStats::new();
        assert_eq!(stats.hit_rate(), 1.0);
        assert_eq!(stats.total_accesses(), 0);
    }

    #[test]
    fn test_cache_stats_recording() {
        let stats = CacheStats::new();
        stats.record_hit();
        stats.record_hit();
        stats.record_miss();
        let (hits, misses) = stats.counts();
        assert_eq!(hits, 2);
        assert_eq!(misses, 1);
        assert!((stats.hit_rate() - 0.667).abs() < 0.01);
    }

    #[test]
    fn test_cache_stats_reset() {
        let stats = CacheStats::new();
        stats.record_hit();
        stats.record_miss();
        let (hit_rate, hits, misses) = stats.get_and_reset();
        assert_eq!(hits, 1);
        assert_eq!(misses, 1);
        assert!((hit_rate - 0.5).abs() < 0.001);
        assert_eq!(stats.total_accesses(), 0);
        assert_eq!(stats.hit_rate(), 1.0);
    }

    #[test]
    fn test_pid_controller_proportional() {
        let mut pid = PidController::new(1.0, 0.0, 0.0);
        let output = pid.compute(0.1, 1.0);
        assert!((output - 0.1).abs() < 0.001);
        let output = pid.compute(-0.2, 1.0);
        assert!((output - (-0.2)).abs() < 0.001);
    }

    #[test]
    fn test_pid_controller_integral() {
        let mut pid = PidController::new(0.0, 1.0, 0.0);
        let output1 = pid.compute(0.1, 1.0);
        assert!((output1 - 0.1).abs() < 0.001);
        let output2 = pid.compute(0.1, 1.0);
        assert!((output2 - 0.2).abs() < 0.001);
    }

    #[test]
    fn test_pid_controller_derivative() {
        let mut pid = PidController::new(0.0, 0.0, 1.0);
        let output1 = pid.compute(0.1, 1.0);
        assert!((output1 - 0.1).abs() < 0.001);
        let output2 = pid.compute(0.1, 1.0);
        assert!(output2.abs() < 0.001);
        let output3 = pid.compute(0.2, 1.0);
        assert!((output3 - 0.1).abs() < 0.001);
    }

    #[test]
    fn test_pid_controller_reset() {
        let mut pid = PidController::new(0.0, 1.0, 0.0);
        pid.compute(0.1, 1.0);
        pid.compute(0.1, 1.0);
        pid.reset();
        let output = pid.compute(0.1, 1.0);
        assert!((output - 0.1).abs() < 0.001);
    }

    #[test]
    fn test_pid_controller_reset_clears_prev_error() {
        let mut pid = PidController::new(0.0, 0.0, 1.0);
        pid.compute(0.5, 1.0);
        pid.reset();
        // After a reset the derivative is taken against 0.0 again, not 0.5.
        let output = pid.compute(0.2, 1.0);
        assert!(
            (output - 0.2).abs() < 0.001,
            "derivative should restart from 0 after reset: got {}",
            output
        );
    }

    #[test]
    fn test_pid_anti_windup() {
        let mut pid = PidController::new(0.0, 1.0, 0.0);
        for _ in 0..1000 {
            pid.compute(10.0, 1.0);
        }
        // With ki = 1 the output equals the integral, which must sit exactly at
        // its upper clamp (100) — a bare `<= 100` would also accept a broken clamp.
        let output = pid.compute(0.0, 1.0);
        assert!(
            (output - 100.0).abs() < 1e-9,
            "integral should saturate at 100: got {}",
            output
        );
    }

    #[test]
    fn test_pid_anti_windup_negative() {
        let mut pid = PidController::new(0.0, 1.0, 0.0);
        for _ in 0..1000 {
            pid.compute(-10.0, 1.0);
        }
        let output = pid.compute(0.0, 1.0);
        assert!(
            (output + 100.0).abs() < 1e-9,
            "integral should saturate at -100: got {}",
            output
        );
    }

    #[test]
    fn test_adaptive_pool_config_default() {
        let config = AdaptivePoolConfig::default();
        assert_eq!(config.min_pool_size, 16);
        assert_eq!(config.max_pool_size, 1024);
        assert!((config.target_memory_fraction - 0.25).abs() < 0.001);
        assert!((config.target_hit_rate - 0.95).abs() < 0.001);
        assert!((config.min_hit_rate_for_growth - 0.90).abs() < 0.001);
        assert_eq!(config.adjustment_interval, Duration::from_secs(10));
        assert_eq!(config.max_growth_step, 16);
        assert_eq!(config.max_shrink_step, 8);
        assert!(config.enabled);
    }

    #[test]
    fn test_cache_stats_get_and_reset_zero_accesses() {
        let stats = CacheStats::new();
        let (rate, hits, misses) = stats.get_and_reset();
        assert_eq!(rate, 1.0, "Hit rate should be 1.0 when no accesses");
        assert_eq!(hits, 0);
        assert_eq!(misses, 0);
    }

    #[test]
    fn test_cache_stats_hit_rate_zero_accesses() {
        let stats = CacheStats::new();
        assert_eq!(
            stats.hit_rate(),
            1.0,
            "Hit rate should be 1.0 when no accesses"
        );
        assert_eq!(stats.total_accesses(), 0);
    }

    #[test]
    fn test_pid_controller_zero_dt() {
        let mut pid = PidController::new(0.5, 0.1, 0.05);
        let output = pid.compute(0.1, 0.0);
        assert!(
            (output - 0.05).abs() < 0.01,
            "Output should be ~0.05 with dt=0 (no D term): got {}",
            output
        );
    }

    #[test]
    fn test_pid_controller_positive_dt() {
        let mut pid = PidController::new(0.5, 0.1, 0.05);
        let _output1 = pid.compute(0.1, 1.0);
        let output2 = pid.compute(0.1, 1.0);
        assert!(
            output2.abs() < 0.1,
            "Output with constant error should be small: got {}",
            output2
        );
    }

    #[test]
    fn test_pid_controller_derivative_term() {
        let mut pid = PidController::new(0.0, 0.0, 1.0);
        let output1 = pid.compute(0.1, 1.0);
        assert!(
            (output1 - 0.1).abs() < 0.001,
            "First D output should be 0.1: got {}",
            output1
        );
        let output2 = pid.compute(0.2, 1.0);
        assert!(
            (output2 - 0.1).abs() < 0.001,
            "Second D output should be 0.1: got {}",
            output2
        );
        let output3 = pid.compute(0.2, 1.0);
        assert!(
            output3.abs() < 0.001,
            "Third D output should be ~0: got {}",
            output3
        );
    }
}