use lru::LruCache;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
use crate::{types::TableId, Config, Result, Value};
/// Bundles a block cache, a row cache, a buffer pool and usage counters.
///
/// Each component is wrapped in its own `Arc<RwLock<..>>`; the methods on this
/// type take `&self` and lock the components they touch.
#[derive(Debug)]
pub struct MemoryManager {
/// LRU cache of blocks, keyed by table id and block id.
block_cache: Arc<RwLock<BlockCache>>,
/// LRU cache of rows, keyed by table id and row key.
row_cache: Arc<RwLock<RowCache>>,
/// Accounting and free list for the byte buffers handed out by `allocate_buffer`.
buffer_pool: Arc<RwLock<BufferPool>>,
/// Hit/miss and allocation counters; a copy is returned by `stats`.
stats: Arc<RwLock<MemoryStats>>,
}
/// LRU map of cached blocks plus the size accounting used by `put_block`.
struct BlockCache {
/// Entries ordered by recency of use; values are shared via `Arc`.
cache: LruCache<BlockKey, Arc<Block>>,
/// Size budget: `put_block` evicts least-recently-used entries while
/// `current_size` plus the incoming block's size exceeds this value.
max_size: usize,
/// Running total of the `size` fields of the entries counted so far.
current_size: usize,
}
/// LRU map of cached rows plus the size accounting used by `put_row`.
struct RowCache {
/// Entries ordered by recency of use; values are shared via `Arc`.
cache: LruCache<RowKey, Arc<CachedRow>>,
/// Size budget: `put_row` evicts least-recently-used entries while
/// `current_size` plus the incoming row's estimated size exceeds this value.
max_size: usize,
/// Running total of the `size` fields of the entries counted so far.
current_size: usize,
}
/// Bookkeeping for byte buffers handed out and returned through `MemoryManager`.
#[derive(Debug)]
struct BufferPool {
/// Returned buffers available for reuse, grouped by their length.
free_buffers: HashMap<usize, Vec<Vec<u8>>>,
/// Incremented by each successful allocation and decremented by each deallocation.
allocated_count: usize,
/// Sum of the sizes of the buffers currently handed out; buffers sitting in
/// `free_buffers` are not included.
total_memory: usize,
/// Upper bound that `total_memory` is checked against when allocating.
max_memory: usize,
}
/// Key of the block cache: a block id scoped to a table.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct BlockKey {
/// Table the block belongs to.
table_id: TableId,
/// Block identifier supplied by the caller of `get_block` / `put_block`.
block_id: u64,
}
/// Key of the row cache: a row key scoped to a table.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct RowKey {
/// Table the row belongs to.
table_id: TableId,
/// Owned copy of the row key string supplied by the caller.
row_key: String,
}
/// Metadata stored in the block cache for one block.
///
/// The struct has no field for the block's bytes; `put_block` records only the
/// length of the data it is given.
#[derive(Debug)]
struct Block {
/// Length in bytes of the data passed to `put_block`.
size: usize,
/// Instant at which the entry was created by `put_block`; not read anywhere
/// in this file.
_last_access: std::time::Instant,
}
/// A row held in the row cache together with its estimated size.
#[derive(Debug)]
struct CachedRow {
/// Column values of the row, as passed to `put_row`.
_data: Vec<Value>,
/// Size estimate computed by `estimate_row_size` when the row was inserted.
size: usize,
}
impl MemoryManager {
/// Builds a manager whose caches and buffer pool are sized from `config`.
///
/// Reads `config.memory.block_cache.max_size`,
/// `config.memory.row_cache.max_size` and `config.memory.max_memory`, casting
/// each to `usize` with `as`. The statistics start from
/// `MemoryStats::default()`. The body shown here always returns `Ok`.
pub fn new(config: &Config) -> Result<Self> {
    let memory = &config.memory;
    let block_budget = memory.block_cache.max_size as usize;
    let row_budget = memory.row_cache.max_size as usize;
    let pool_limit = memory.max_memory as usize;

    Ok(Self {
        block_cache: Arc::new(RwLock::new(BlockCache::new(block_budget))),
        row_cache: Arc::new(RwLock::new(RowCache::new(row_budget))),
        buffer_pool: Arc::new(RwLock::new(BufferPool::new(pool_limit))),
        stats: Arc::new(RwLock::new(MemoryStats::default())),
    })
}
/// Looks up a block and records the outcome in the statistics.
///
/// The block cache is locked for writing because an LRU lookup updates the
/// recency order. Returns a new `Arc` handle to the cached block on a hit and
/// `None` on a miss; `block_cache_hits` or `block_cache_misses` is bumped
/// accordingly while the cache lock is still held.
pub fn get_block(&self, table_id: &TableId, block_id: u64) -> Option<Arc<Block>> {
    let lookup = BlockKey {
        table_id: table_id.clone(),
        block_id,
    };
    let mut guard = self.block_cache.write();
    let hit = guard.cache.get(&lookup).cloned();

    let mut counters = self.stats.write();
    if hit.is_some() {
        counters.block_cache_hits += 1;
    } else {
        counters.block_cache_misses += 1;
    }
    hit
}
/// Inserts (or replaces) a block entry, evicting least-recently-used entries
/// until the new block fits in the size budget.
///
/// Only `data.len()` is recorded: `Block` has no field for the bytes, so
/// `data` is dropped when this call returns.
///
/// `current_size` is kept equal to the sum of the sizes of the entries that
/// are actually in the cache:
/// * an existing entry under the same key is removed and un-counted first,
/// * entries evicted to make room are un-counted,
/// * an entry displaced by the LRU map's own entry-count capacity is
///   un-counted as well.
///
/// A block larger than `max_size` empties the cache and is then inserted
/// anyway.
pub fn put_block(&self, table_id: &TableId, block_id: u64, data: Vec<u8>) {
    let key = BlockKey {
        table_id: table_id.clone(),
        block_id,
    };
    let block = Arc::new(Block {
        size: data.len(),
        _last_access: std::time::Instant::now(),
    });
    let incoming = block.size;
    let mut cache = self.block_cache.write();

    // Replacing a key: drop the old entry up front so its bytes are neither
    // counted twice nor force unrelated entries out of the cache.
    if let Some(previous) = cache.cache.pop(&key) {
        cache.current_size = cache.current_size.saturating_sub(previous.size);
    }

    // Make room within the size budget.
    while cache.current_size.saturating_add(incoming) > cache.max_size {
        match cache.cache.pop_lru() {
            Some((_, evicted)) => {
                cache.current_size = cache.current_size.saturating_sub(evicted.size);
            }
            // Nothing left to evict.
            None => break,
        }
    }

    // `push` hands back any entry the map drops because it is at its
    // entry-count capacity; that entry must leave the size total too.
    if let Some((_, displaced)) = cache.cache.push(key, block) {
        cache.current_size = cache.current_size.saturating_sub(displaced.size);
    }
    cache.current_size += incoming;
}
/// Looks up a cached row and records the outcome in the statistics.
///
/// The row cache is locked for writing because an LRU lookup updates the
/// recency order. Returns a new `Arc` handle to the cached row on a hit and
/// `None` on a miss; `row_cache_hits` or `row_cache_misses` is bumped
/// accordingly while the cache lock is still held.
pub fn get_row(&self, table_id: &TableId, row_key: &str) -> Option<Arc<CachedRow>> {
    let lookup = RowKey {
        table_id: table_id.clone(),
        row_key: row_key.to_string(),
    };
    let mut guard = self.row_cache.write();
    let hit = guard.cache.get(&lookup).cloned();

    let mut counters = self.stats.write();
    if hit.is_some() {
        counters.row_cache_hits += 1;
    } else {
        counters.row_cache_misses += 1;
    }
    hit
}
/// Inserts (or replaces) a row, evicting least-recently-used rows until the
/// new row's estimated size fits in the size budget.
///
/// The row's size is the estimate produced by `estimate_row_size`.
///
/// `current_size` is kept equal to the sum of the sizes of the entries that
/// are actually in the cache:
/// * an existing entry under the same key is removed and un-counted first,
/// * entries evicted to make room are un-counted,
/// * an entry displaced by the LRU map's own entry-count capacity is
///   un-counted as well.
///
/// A row whose estimate exceeds `max_size` empties the cache and is then
/// inserted anyway.
pub fn put_row(&self, table_id: &TableId, row_key: &str, data: Vec<Value>) {
    let key = RowKey {
        table_id: table_id.clone(),
        row_key: row_key.to_string(),
    };
    let size = self.estimate_row_size(&data);
    let row = Arc::new(CachedRow { _data: data, size });
    let mut cache = self.row_cache.write();

    // Replacing a key: drop the old entry up front so its bytes are neither
    // counted twice nor force unrelated entries out of the cache.
    if let Some(previous) = cache.cache.pop(&key) {
        cache.current_size = cache.current_size.saturating_sub(previous.size);
    }

    // Make room within the size budget.
    while cache.current_size.saturating_add(size) > cache.max_size {
        match cache.cache.pop_lru() {
            Some((_, evicted)) => {
                cache.current_size = cache.current_size.saturating_sub(evicted.size);
            }
            // Nothing left to evict.
            None => break,
        }
    }

    // `push` hands back any entry the map drops because it is at its
    // entry-count capacity; that entry must leave the size total too.
    if let Some((_, displaced)) = cache.cache.push(key, row) {
        cache.current_size = cache.current_size.saturating_sub(displaced.size);
    }
    cache.current_size += size;
}
/// Hands out a buffer of exactly `size` bytes, reusing a previously returned
/// buffer of the same length when one is available.
///
/// The memory limit is enforced on every path: a reused buffer counts
/// toward `total_memory` just like a fresh one, so the request is checked
/// against `max_memory` before the free list is consulted. The addition is
/// overflow-checked, so an absurdly large `size` is reported as a limit
/// violation rather than wrapping or panicking.
///
/// # Errors
///
/// Returns `Error::Memory` when `total_memory + size` would exceed
/// `max_memory`. Nothing is modified in that case.
pub fn allocate_buffer(&self, size: usize) -> Result<Vec<u8>> {
    let mut pool = self.buffer_pool.write();

    let limit = pool.max_memory;
    let in_use = pool.total_memory;
    let fits = in_use
        .checked_add(size)
        .map_or(false, |total| total <= limit);
    if !fits {
        return Err(crate::Error::Memory(format!(
            "Memory limit exceeded: requested {} bytes would exceed limit of {} bytes (current usage: {} bytes)",
            size, limit, in_use
        )));
    }

    // Prefer a pooled buffer of the exact length; otherwise allocate a
    // zero-filled one.
    let buffer = pool
        .free_buffers
        .get_mut(&size)
        .and_then(|bucket| bucket.pop())
        .unwrap_or_else(|| vec![0u8; size]);

    pool.allocated_count += 1;
    pool.total_memory += size;

    let mut stats = self.stats.write();
    stats.buffer_allocations += 1;
    stats.total_memory_used = pool.total_memory;
    Ok(buffer)
}
/// Returns a buffer to the pool so a later `allocate_buffer` call of the
/// same length can reuse it.
///
/// The buffer is zeroed (its length is unchanged) before being stored under
/// its length in the free list. `total_memory` and `allocated_count` are
/// reduced with saturating arithmetic: a buffer whose length the pool never
/// accounted for (for example one resized by the caller or created
/// elsewhere) clamps the counters at zero instead of underflowing.
pub fn deallocate_buffer(&self, mut buffer: Vec<u8>) {
    let size = buffer.len();
    // Wipe the old contents so a reused buffer starts out all-zero, like a
    // freshly allocated one.
    buffer.fill(0);

    let mut pool = self.buffer_pool.write();
    pool.total_memory = pool.total_memory.saturating_sub(size);
    pool.allocated_count = pool.allocated_count.saturating_sub(1);
    pool.free_buffers.entry(size).or_default().push(buffer);

    let mut stats = self.stats.write();
    stats.buffer_deallocations += 1;
    stats.total_memory_used = pool.total_memory;
}
/// Returns a copy of the current counters, taken under a read lock.
///
/// The body shown here always returns `Ok`.
pub fn stats(&self) -> Result<MemoryStats> {
    let snapshot = self.stats.read().clone();
    Ok(snapshot)
}
/// Empties the block cache and then the row cache, resetting both size totals
/// to zero.
///
/// The two locks are taken one after the other, never at the same time. The
/// statistics and the buffer pool are left untouched.
pub fn clear_caches(&self) {
    let mut blocks = self.block_cache.write();
    blocks.cache.clear();
    blocks.current_size = 0;
    // Release the block cache before touching the row cache.
    drop(blocks);

    let mut rows = self.row_cache.write();
    rows.cache.clear();
    rows.current_size = 0;
}
/// Estimates the size of a row as the sum of `estimate_value_size` over its
/// values; an empty row yields 0.
fn estimate_row_size(&self, data: &[Value]) -> usize {
    let mut total = 0;
    for value in data {
        total += self.estimate_value_size(value);
    }
    total
}
/// Estimates the size of a single value.
///
/// Scalar variants map to a fixed number, variable-length variants to the
/// `len()` of their payload, and container variants to the sum of the
/// estimates of their elements (computed recursively).
// `self` is only threaded through the recursive calls, which is what the
// silenced lint reports.
#[allow(clippy::only_used_in_recursion)]
fn estimate_value_size(&self, value: &Value) -> usize {
    match value {
        // Fixed-size variants, grouped by their estimate.
        Value::Null | Value::Boolean(_) | Value::TinyInt(_) => 1,
        Value::SmallInt(_) => 2,
        Value::Integer(_) | Value::Date(_) | Value::Float32(_) => 4,
        Value::BigInt(_)
        | Value::Counter(_)
        | Value::Float(_)
        | Value::Timestamp(_)
        | Value::Time(_) => 8,
        Value::Duration { .. } => 12,
        Value::Uuid(_) | Value::Tombstone(_) => 16,

        // Variable-length payloads: whatever `len()` reports.
        Value::Text(s) => s.len(),
        Value::Blob(b) => b.len(),
        Value::Inet(bytes) => bytes.len(),
        Value::Varint(data) => data.len(),
        // Fixed 4 on top of the length of the unscaled part.
        Value::Decimal { unscaled, .. } => 4 + unscaled.len(),
        // Measured by rendering the JSON value to a string.
        Value::Json(json) => json.to_string().len(),

        // Wrappers and containers: recurse into the contents.
        Value::Frozen(boxed_value) => self.estimate_value_size(boxed_value),
        Value::List(items) => items.iter().map(|v| self.estimate_value_size(v)).sum(),
        Value::Set(items) => items.iter().map(|v| self.estimate_value_size(v)).sum(),
        Value::Tuple(items) => items.iter().map(|v| self.estimate_value_size(v)).sum(),
        Value::Map(map) => map
            .iter()
            .map(|(k, v)| self.estimate_value_size(k) + self.estimate_value_size(v))
            .sum(),
        // Fields without a value contribute nothing.
        Value::Udt(udt) => udt
            .fields
            .iter()
            .map(|f| f.value.as_ref().map_or(0, |v| self.estimate_value_size(v)))
            .sum(),
    }
}
}
/// Summary-style `Debug` output: prints the size budget, the current size
/// total and the number of entries instead of the entries themselves.
impl std::fmt::Debug for BlockCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.cache.len();
        let mut out = f.debug_struct("BlockCache");
        out.field("max_size", &self.max_size);
        out.field("current_size", &self.current_size);
        out.field("cache_len", &entry_count);
        out.finish()
    }
}
impl BlockCache {
    /// Creates an empty block cache with the given size budget.
    ///
    /// The underlying LRU map is created with a fixed capacity of 1000
    /// entries, independent of `max_size`.
    fn new(max_size: usize) -> Self {
        let entry_capacity = NonZeroUsize::new(1000).expect("capacity must be non-zero");
        let cache = LruCache::new(entry_capacity);
        Self {
            cache,
            max_size,
            current_size: 0,
        }
    }
}
/// Summary-style `Debug` output: prints the size budget, the current size
/// total and the number of entries instead of the entries themselves.
impl std::fmt::Debug for RowCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.cache.len();
        let mut out = f.debug_struct("RowCache");
        out.field("max_size", &self.max_size);
        out.field("current_size", &self.current_size);
        out.field("cache_len", &entry_count);
        out.finish()
    }
}
impl RowCache {
    /// Creates an empty row cache with the given size budget.
    ///
    /// The underlying LRU map is created with a fixed capacity of 1000
    /// entries, independent of `max_size`.
    fn new(max_size: usize) -> Self {
        let entry_capacity = NonZeroUsize::new(1000).expect("capacity must be non-zero");
        let cache = LruCache::new(entry_capacity);
        Self {
            cache,
            max_size,
            current_size: 0,
        }
    }
}
impl BufferPool {
    /// Creates a pool with no free buffers, nothing handed out, and the given
    /// memory limit.
    fn new(max_memory: usize) -> Self {
        let free_buffers = HashMap::new();
        Self {
            free_buffers,
            allocated_count: 0,
            total_memory: 0,
            max_memory,
        }
    }
}
/// Counters describing cache effectiveness and buffer-pool activity.
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Block lookups that found an entry.
    pub block_cache_hits: u64,
    /// Block lookups that found nothing.
    pub block_cache_misses: u64,
    /// Row lookups that found an entry.
    pub row_cache_hits: u64,
    /// Row lookups that found nothing.
    pub row_cache_misses: u64,
    /// Bytes held by currently handed-out buffers, as of the last pool operation.
    pub total_memory_used: usize,
    /// Number of successful buffer allocations.
    pub buffer_allocations: u64,
    /// Number of buffer deallocations.
    pub buffer_deallocations: u64,
}

impl MemoryStats {
    /// Fraction of block lookups that were hits, or `0.0` before any lookup.
    pub fn block_cache_hit_rate(&self) -> f64 {
        Self::hit_ratio(self.block_cache_hits, self.block_cache_misses)
    }

    /// Fraction of row lookups that were hits, or `0.0` before any lookup.
    pub fn row_cache_hit_rate(&self) -> f64 {
        Self::hit_ratio(self.row_cache_hits, self.row_cache_misses)
    }

    /// `hits / (hits + misses)` as a float, with `0.0` when both are zero.
    fn hit_ratio(hits: u64, misses: u64) -> f64 {
        match hits + misses {
            0 => 0.0,
            total => hits as f64 / total as f64,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::TableId;
/// A freshly built manager reports zeroed block-cache counters.
#[test]
fn test_memory_manager_creation() {
    let manager = MemoryManager::new(&Config::default()).unwrap();
    let snapshot = manager.stats().unwrap();
    assert_eq!(snapshot.block_cache_hits, 0);
    assert_eq!(snapshot.block_cache_misses, 0);
}
/// A block is absent before `put_block` and present, with the recorded
/// length, afterwards.
#[test]
fn test_block_cache() {
    let manager = MemoryManager::new(&Config::default()).unwrap();
    let table = TableId::new("test_table");
    let block_id = 1;
    let payload = vec![1, 2, 3, 4, 5];

    assert!(manager.get_block(&table, block_id).is_none());

    manager.put_block(&table, block_id, payload.clone());
    let cached = manager.get_block(&table, block_id);
    assert!(cached.is_some());
    assert_eq!(cached.unwrap().size, payload.len());
}
/// With an 8-byte budget, inserting a second block evicts the first; the
/// following miss and hit are both counted.
#[test]
fn test_block_cache_eviction_updates_stats() {
    let mut config = Config::default();
    config.memory.block_cache.max_size = 8;
    let manager = MemoryManager::new(&config).unwrap();
    let table = TableId::new("ks_table");

    manager.put_block(&table, 1, vec![0u8; 8]);
    manager.put_block(&table, 2, vec![0u8; 4]);

    let evicted = manager.get_block(&table, 1);
    assert!(evicted.is_none());
    let retained = manager.get_block(&table, 2);
    assert!(retained.is_some());

    let snapshot = manager.stats().unwrap();
    assert_eq!(snapshot.block_cache_hits, 1);
    assert_eq!(snapshot.block_cache_misses, 1);
}
/// A row is absent before `put_row` and returned unchanged afterwards.
#[test]
fn test_row_cache() {
    let manager = MemoryManager::new(&Config::default()).unwrap();
    let table = TableId::new("test_table");
    let key = "test_key";
    let expected = vec![Value::Integer(42), Value::Text("hello".to_string())];

    assert!(manager.get_row(&table, key).is_none());

    manager.put_row(&table, key, expected.clone());
    let cached = manager.get_row(&table, key);
    assert!(cached.is_some());
    assert_eq!(cached.unwrap()._data, expected);
}
/// With an 8-byte budget and 4-byte rows, the third insert evicts the oldest
/// row; the following miss and hit are both counted.
#[test]
fn test_row_cache_eviction_and_stats() {
    let mut config = Config::default();
    config.memory.row_cache.max_size = 8;
    let manager = MemoryManager::new(&config).unwrap();
    let table = TableId::new("ks_table");
    // Single-column text row.
    let text_row = |text: &str| vec![Value::Text(text.into())];

    manager.put_row(&table, "k1", text_row("abcd"));
    manager.put_row(&table, "k2", text_row("efgh"));
    manager.put_row(&table, "k3", text_row("ijkl"));

    let oldest = manager.get_row(&table, "k1");
    assert!(oldest.is_none());
    let newest = manager.get_row(&table, "k3");
    assert!(newest.is_some());

    let snapshot = manager.stats().unwrap();
    assert_eq!(snapshot.row_cache_hits, 1);
    assert_eq!(snapshot.row_cache_misses, 1);
}
/// Allocating, returning and re-allocating a buffer yields the requested
/// length each time.
#[test]
fn test_buffer_pool() {
    const BUFFER_LEN: usize = 1024;
    let manager = MemoryManager::new(&Config::default()).unwrap();

    let first = manager.allocate_buffer(BUFFER_LEN).unwrap();
    assert_eq!(first.len(), BUFFER_LEN);
    manager.deallocate_buffer(first);

    let second = manager.allocate_buffer(BUFFER_LEN).unwrap();
    assert_eq!(second.len(), BUFFER_LEN);
}
/// After `clear_caches`, entries previously stored in either cache are gone.
#[test]
fn test_clear_caches() {
    let mut config = Config::default();
    config.memory.block_cache.max_size = 8;
    config.memory.row_cache.max_size = 8;
    let manager = MemoryManager::new(&config).unwrap();
    let table = TableId::new("ks_table");

    // Populate both caches, then wipe them.
    manager.put_block(&table, 1, vec![0u8; 8]);
    manager.put_row(&table, "k1", vec![Value::Text("abcd".into())]);
    manager.clear_caches();

    let block_after = manager.get_block(&table, 1);
    assert!(block_after.is_none());
    let row_after = manager.get_row(&table, "k1");
    assert!(row_after.is_none());
}
/// Fresh allocations are refused once the limit is reached, and returning
/// buffers makes room again.
#[test]
fn test_memory_limit_enforcement() {
    const MIB: usize = 1024 * 1024;

    let mut config = Config::default();
    config.memory.max_memory = 128 * 1024 * 1024;
    let manager = MemoryManager::new(&config).unwrap();

    // Two 64 MiB buffers fill the 128 MiB limit exactly.
    let buffer1 = manager
        .allocate_buffer(64 * MIB)
        .expect("first 64MB should succeed");
    let buffer2 = manager
        .allocate_buffer(64 * MIB)
        .expect("second 64MB should succeed");

    // Any further request must be rejected with a descriptive error.
    let result = manager.allocate_buffer(1024);
    assert!(result.is_err(), "allocation exceeding limit should fail");
    if let Some(e) = result.err() {
        assert!(
            e.to_string().contains("Memory limit exceeded"),
            "error should mention memory limit"
        );
    }

    let at_limit = manager.stats().unwrap();
    assert_eq!(
        at_limit.buffer_allocations, 2,
        "should have 2 successful allocations"
    );
    assert_eq!(
        at_limit.total_memory_used,
        128 * MIB,
        "should be at memory limit"
    );

    // Returning one buffer halves the usage.
    manager.deallocate_buffer(buffer1);
    let after_free = manager.stats().unwrap();
    assert_eq!(after_free.buffer_deallocations, 1);
    assert_eq!(
        after_free.total_memory_used,
        64 * MIB,
        "memory should be freed"
    );

    let buffer3 = manager
        .allocate_buffer(32 * MIB)
        .expect("allocation after free should succeed");

    // Returning everything brings the usage back to zero.
    manager.deallocate_buffer(buffer2);
    manager.deallocate_buffer(buffer3);
    let final_stats = manager.stats().unwrap();
    assert_eq!(
        final_stats.total_memory_used, 0,
        "all memory should be freed"
    );
}
/// A buffer taken from the free list counts toward the memory limit just
/// like a freshly allocated one.
#[test]
fn test_memory_limit_with_buffer_reuse() {
    const MIB: usize = 1024 * 1024;

    let mut config = Config::default();
    config.memory.max_memory = 128 * 1024 * 1024;
    let manager = MemoryManager::new(&config).unwrap();

    let buffer1 = manager
        .allocate_buffer(64 * MIB)
        .expect("first 64MB should succeed");
    let buffer2 = manager
        .allocate_buffer(64 * MIB)
        .expect("second 64MB should succeed");

    // Return one buffer so the next same-size request can reuse it.
    manager.deallocate_buffer(buffer1);
    let after_free = manager.stats().unwrap();
    assert_eq!(
        after_free.total_memory_used,
        64 * MIB,
        "should have 64MB in use after deallocation"
    );

    let buffer3 = manager
        .allocate_buffer(64 * MIB)
        .expect("reuse should succeed");
    let after_reuse = manager.stats().unwrap();
    assert_eq!(
        after_reuse.total_memory_used,
        128 * MIB,
        "reused buffer should count toward memory limit"
    );

    // The limit is reached again, so a further request must fail.
    let result = manager.allocate_buffer(1024);
    assert!(
        result.is_err(),
        "allocation should fail when limit reached via buffer reuse"
    );
    if let Some(e) = result.err() {
        assert!(
            e.to_string().contains("Memory limit exceeded"),
            "error should mention memory limit"
        );
    }

    manager.deallocate_buffer(buffer2);
    manager.deallocate_buffer(buffer3);
    let final_stats = manager.stats().unwrap();
    assert_eq!(final_stats.total_memory_used, 0, "all memory freed");
}
}