use super::*;
use crate::error::Result;
use crate::memory::{SecureMemoryPool, SecurePooledPtr};
use std::sync::{Arc, RwLock, Mutex};
use std::sync::atomic::{AtomicU64, AtomicU32, AtomicU16, AtomicU8, Ordering};
use std::collections::HashMap;
use std::io::{Read, Seek, SeekFrom};
use std::fs::File;
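/// A page cache with one global LRU eviction list.
///
/// All page frames live in a single aligned buffer drawn from a
/// `SecureMemoryPool`. Cached pages are located through `hash_table`
/// (keyed by file id and page id), eviction order is tracked in `lru_list`,
/// and unused frames are chained through the lock-free free list rooted at
/// `free_list`. Node index 0 is never handed out; valid nodes start at 1.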
pub struct SingleLruPageCache {
config: PageCacheConfig,
page_memory: Arc<SecureMemoryPool>,
page_buffer: SecurePooledPtr,
nodes: Vec<CacheNode>,
hash_table: HashTable,
lru_list: LruList,
file_info: RwLock<HashMap<FileId, Arc<FileInfo>>>,
free_list: AtomicU32,
stats: CacheStatistics,
global_mutex: Mutex<()>,
next_file_id: AtomicU32,
}
impl SingleLruPageCache {
pub fn new(config: PageCacheConfig) -> Result<Self> {
config.validate()?;
let page_count = config.calculate_page_count();
let hash_size = config.calculate_hash_table_size();
let pool_config = if config.memory.use_secure_pools {
crate::memory::SecurePoolConfig::large_secure()
} else {
crate::memory::SecurePoolConfig::large_performance()
};
let page_memory = Arc::new(SecureMemoryPool::new(pool_config)?);
let total_memory = page_count * config.page_size;
let page_buffer = page_memory.allocate_aligned(total_memory, config.memory.alignment)?;
if config.memory.kernel_advice.huge_pages && config.use_huge_pages {
#[cfg(target_os = "linux")]
unsafe {
libc::madvise(
page_buffer.as_ptr() as *mut libc::c_void,
total_memory,
libc::MADV_HUGEPAGE,
);
}
}
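        // Hint that the whole buffer will be accessed soon so the kernel can
        // fault it in eagerly.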
if config.memory.kernel_advice.will_need {
#[cfg(target_os = "linux")]
unsafe {
libc::madvise(
page_buffer.as_ptr() as *mut libc::c_void,
total_memory,
libc::MADV_WILLNEED,
);
}
}
        // Node 0 is never handed out, so allocate page_count + 1 nodes and
        // chain nodes 1..=page_count together as the initial free list.
        let mut nodes = Vec::with_capacity(page_count + 1);
        nodes.resize_with(page_count + 1, CacheNode::default);
        for i in 1..page_count {
            nodes[i].hash_link.store((i + 1) as u32, Ordering::Relaxed);
        }
        nodes[page_count].hash_link.store(INVALID_NODE, Ordering::Relaxed);
let hash_table = HashTable::new(hash_size);
Ok(Self {
config,
page_memory,
page_buffer,
nodes,
hash_table,
lru_list: LruList::new(),
file_info: RwLock::new(HashMap::new()),
            free_list: AtomicU32::new(1),
            stats: CacheStatistics::new(),
global_mutex: Mutex::new(()),
next_file_id: AtomicU32::new(1),
})
}
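    /// Registers an open file descriptor with the cache and returns the
    /// `FileId` used to address its pages in later `read` calls.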
pub fn register_file(&self, fd: i32) -> Result<FileId> {
let file_id = self.next_file_id.fetch_add(1, Ordering::Relaxed);
let file_info = Arc::new(FileInfo::new(fd));
let mut files = self.file_info.write().map_err(|_| CacheError::AllocationFailed)?;
files.insert(file_id, file_info);
Ok(file_id)
}
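    /// Reads `length` bytes of `file_id` starting at byte `offset`.
    ///
    /// A request that stays within a single page takes the single-page fast
    /// path; a request that crosses a page boundary is assembled from
    /// multiple pages through the provided `CacheBuffer`.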
    pub fn read(&self, file_id: FileId, offset: u64, length: usize, buffer: &mut CacheBuffer) -> Result<&[u8]> {
        // A zero-length request would underflow the end_page calculation below.
        if length == 0 {
            return Ok(&[]);
        }
        let page_offset = offset % PAGE_SIZE as u64;
        let start_page = offset / PAGE_SIZE as u64;
        let end_page = (offset + length as u64 - 1) / PAGE_SIZE as u64;
if start_page == end_page {
self.read_single_page(file_id, start_page as PageId, page_offset as usize, length, buffer)
} else {
self.read_multi_page(file_id, offset, length, buffer)
}
}
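    /// Fast path for a read contained in one page.
    ///
    /// On a hit the node is pinned (its ref count incremented), the caller
    /// waits for the page data to finish loading, and a node whose ref count
    /// was previously zero is unlinked from the LRU list while it is in use.
    /// On a miss the page is brought in via `load_page`.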
fn read_single_page(&self, file_id: FileId, page_id: PageId, offset: usize, length: usize, buffer: &mut CacheBuffer) -> Result<&[u8]> {
if let Some(node_idx) = self.hash_table.find(&self.nodes, file_id, page_id) {
let node = &self.nodes[node_idx as usize];
let old_ref = node.inc_ref();
self.wait_for_page_loaded(node);
node.update_last_access();
self.stats.record_hit(CacheHitType::Hit);
if old_ref == 0 {
self.lru_list.remove(&self.nodes, node_idx);
}
let page_ptr = node.page_data_ptr();
let page_data = unsafe { std::slice::from_raw_parts(page_ptr, PAGE_SIZE) };
buffer.set_node(self, node_idx);
return Ok(&page_data[offset..offset + length]);
}
self.load_page(file_id, page_id, offset, length, buffer)
}
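    /// Slow path for a read spanning several pages.
    ///
    /// Each page is pinned (and loaded on a miss) while the hash bucket of
    /// the following page is prefetched; the pinned nodes are then handed to
    /// the `CacheBuffer`, which assembles the requested byte range.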
fn read_multi_page(&self, file_id: FileId, offset: u64, length: usize, buffer: &mut CacheBuffer) -> Result<&[u8]> {
let start_page = (offset / PAGE_SIZE as u64) as PageId;
let end_page = ((offset + length as u64 - 1) / PAGE_SIZE as u64) as PageId;
let num_pages = (end_page - start_page + 1) as usize;
let mut page_nodes = Vec::with_capacity(num_pages);
for page_id in start_page..=end_page {
if page_id < end_page {
let bucket_idx = self.hash_table.hash_index(file_id, page_id + 1);
if bucket_idx < self.hash_table.size() {
prefetch_hint(&self.hash_table.buckets[bucket_idx] as *const _ as *const u8);
}
}
            let node_idx = if let Some(idx) = self.hash_table.find(&self.nodes, file_id, page_id) {
                let node = &self.nodes[idx as usize];
                let old_ref = node.inc_ref();
                self.wait_for_page_loaded(node);
                node.update_last_access();
                // Mirror the single-page hit path: a node whose ref count was
                // zero is still on the LRU list and must be unlinked while pinned.
                if old_ref == 0 {
                    self.lru_list.remove(&self.nodes, idx);
                }
                idx
            } else {
                self.allocate_and_load_page(file_id, page_id)?
            };
page_nodes.push(node_idx);
}
buffer.setup_multi_page(self, page_nodes, offset, length);
self.stats.record_hit(CacheHitType::Mix);
Ok(buffer.data())
}
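    /// Returns a free node index, first by popping the lock-free free list;
    /// if the pop races with another thread or the list is empty, falls back
    /// to evicting the least recently used page.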
fn allocate_page(&self) -> Result<NodeIndex> {
let current_free = self.free_list.load(Ordering::Relaxed);
if current_free != INVALID_NODE {
let next_free = self.nodes[current_free as usize].hash_link.load(Ordering::Relaxed);
if self.free_list.compare_exchange_weak(
current_free, next_free, Ordering::Relaxed, Ordering::Relaxed
).is_ok() {
return Ok(current_free);
}
}
self.evict_lru_page()
}
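    /// Evicts the least recently used page under the global mutex.
    ///
    /// A pinned LRU tail is treated as a full cache rather than scanning
    /// further; otherwise the node is unlinked from the hash table and LRU
    /// list and reset for reuse.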
fn evict_lru_page(&self) -> Result<NodeIndex> {
let _lock = self.global_mutex.lock().map_err(|_| CacheError::AllocationFailed)?;
let lru_node = self.lru_list.get_lru_node();
if lru_node == INVALID_NODE {
return Err(CacheError::CacheFull.into());
}
let node = &self.nodes[lru_node as usize];
if node.ref_count() > 0 {
return Err(CacheError::CacheFull.into());
}
let file_id = node.file_id();
let page_id = node.page_id();
if file_id != u32::MAX && page_id != u32::MAX {
self.hash_table.remove(&self.nodes, lru_node, file_id, page_id);
}
self.lru_list.remove(&self.nodes, lru_node);
node.reset();
self.stats.record_hit(CacheHitType::EvictedOthers);
Ok(lru_node)
}
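    /// Handles a single-page miss: allocates a node, loads the page into it
    /// and returns the requested byte range.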
fn load_page(&self, file_id: FileId, page_id: PageId, offset: usize, length: usize, buffer: &mut CacheBuffer) -> Result<&[u8]> {
let node_idx = self.allocate_page()?;
self.allocate_and_load_page_with_node(file_id, page_id, node_idx)?;
let node = &self.nodes[node_idx as usize];
let page_ptr = node.page_data_ptr();
let page_data = unsafe { std::slice::from_raw_parts(page_ptr, PAGE_SIZE) };
buffer.set_node(self, node_idx);
Ok(&page_data[offset..offset + length])
}
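    /// Allocates a node for `(file_id, page_id)` and loads its data,
    /// returning the node index (used by the multi-page path).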
fn allocate_and_load_page(&self, file_id: FileId, page_id: PageId) -> Result<NodeIndex> {
let node_idx = self.allocate_page()?;
self.allocate_and_load_page_with_node(file_id, page_id, node_idx)?;
Ok(node_idx)
}
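    /// Binds `node_idx` to `(file_id, page_id)`: computes the node's slot in
    /// the page buffer (slot 0 belongs to node index 1), pins the node,
    /// publishes it in the hash table, loads the page data and finally marks
    /// the node loaded so spinning readers may proceed.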
fn allocate_and_load_page_with_node(&self, file_id: FileId, page_id: PageId, node_idx: NodeIndex) -> Result<()> {
let node = &self.nodes[node_idx as usize];
let page_offset = (node_idx as usize - 1) * self.config.page_size;
let page_ptr = unsafe { self.page_buffer.as_ptr().add(page_offset) };
node.initialize(file_id, page_id, page_ptr as *mut u8);
node.inc_ref();
self.hash_table.insert(&self.nodes, node_idx, file_id, page_id);
self.load_page_data(file_id, page_id, page_ptr as *mut u8)?;
node.mark_loaded();
Ok(())
}
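    /// Fills the frame at `page_ptr` for `(file_id, page_id)`.
    ///
    /// The page is currently only zero-filled; see the comment in the body
    /// for where the positional file read belongs.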
fn load_page_data(&self, file_id: FileId, page_id: PageId, page_ptr: *mut u8) -> Result<()> {
let files = self.file_info.read().map_err(|_| CacheError::FileNotFound)?;
let file_info = files.get(&file_id).ok_or(CacheError::FileNotFound)?;
if file_info.is_closed() {
return Err(CacheError::FileNotFound.into());
}
        // Offset of this page in the backing file; a positional read (for
        // example pread on the registered descriptor) would start here. For
        // now the page is only zero-filled below, so callers always observe
        // initialized memory.
        let _file_offset = (page_id as u64) * (PAGE_SIZE as u64);
unsafe {
std::ptr::write_bytes(page_ptr, 0, PAGE_SIZE);
}
Ok(())
}
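    /// Spin-waits until the thread loading this node marks its page data as
    /// loaded.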
fn wait_for_page_loaded(&self, node: &CacheNode) {
while !node.is_page_loaded() {
std::hint::spin_loop();
}
}
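    /// Returns the cache hit/miss statistics.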
pub fn stats(&self) -> &CacheStatistics {
&self.stats
}
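    /// Returns the configuration this cache was built with.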
pub fn config(&self) -> &PageCacheConfig {
&self.config
}
}