use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use core::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use lazy_static::lazy_static;
use spin::Mutex;
use super::mlprefetch::{MlPrefetchEngine, PrefetchStats};
use crate::cache::arc::{ARC, Arc as ArcCache, SHARDED_ARC};
use crate::fscore::structs::Dva;
/// Maximum number of requests allowed in the pending prefetch queue.
const MAX_PENDING_PREFETCHES: usize = 64;
/// EMA hit rate below which prefetching is suspended entirely (after warm-up).
const MIN_HIT_RATE: f32 = 0.10;
/// EMA hit rate above which the distance multiplier turns aggressive (1.5x).
const AGGRESSIVE_HIT_RATE: f32 = 0.50;
/// Priority assigned to ordinary prefetch requests.
const DEFAULT_PREFETCH_PRIORITY: u8 = 25;
/// Priority assigned while the EMA hit rate exceeds 0.7.
const HIGH_PREFETCH_PRIORITY: u8 = 75;
/// A single queued prefetch: what to read, how much, and how to cache it.
#[derive(Debug, Clone)]
pub struct PrefetchRequest {
/// Target address of the block to prefetch.
pub dva: Dva,
/// Number of bytes to read.
pub size: u64,
/// Scheduling priority (higher = more urgent).
pub priority: u8,
/// Whether the fetched block should be steered toward L2ARC.
pub use_l2arc: bool,
/// Adapter logical-clock tick at which the request was created.
pub timestamp: u64,
}
/// Outcome of attempting a single prefetch request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchResult {
/// The prefetch completed (or was scheduled) successfully.
Success,
/// The block was already resident in the ARC; nothing to do.
AlreadyCached,
/// The pending queue is full; the request is not executed now.
QueueFull,
/// Prefetching is globally disabled via `PREFETCH_ENABLED`.
Disabled,
/// The underlying read failed.
IoError,
}
/// Global on/off switch for the prefetcher.
pub static PREFETCH_ENABLED: AtomicBool = AtomicBool::new(true);
/// Count of prefetches that completed successfully.
pub static PREFETCH_SUCCESS: AtomicU64 = AtomicU64::new(0);
/// Count of prefetches skipped because the block was already cached.
pub static PREFETCH_ALREADY_CACHED: AtomicU64 = AtomicU64::new(0);
/// Count of predictions dropped because the pending queue was full.
pub static PREFETCH_QUEUE_FULL: AtomicU64 = AtomicU64::new(0);
/// Count of prefetches that failed with an I/O error.
pub static PREFETCH_IO_ERRORS: AtomicU64 = AtomicU64::new(0);
/// Count of successful prefetches that were steered toward L2ARC.
pub static PREFETCH_L2ARC_COUNT: AtomicU64 = AtomicU64::new(0);
lazy_static! {
// Singleton adapter that owns the global (cross-file) prediction engine.
static ref PREFETCH_ADAPTER: Mutex<MlPrefetchAdapter> = Mutex::new(MlPrefetchAdapter::new());
// Pending prefetch requests, keyed by target DVA to deduplicate.
static ref PREFETCH_QUEUE: Mutex<BTreeMap<Dva, PrefetchRequest>> = Mutex::new(BTreeMap::new());
// One independent prediction engine per file id (populated by `record_read`).
static ref FILE_ENGINES: Mutex<BTreeMap<u64, MlPrefetchEngine>> = Mutex::new(BTreeMap::new());
}
/// Glue between the ML prediction engine and the ARC-backed prefetch queue.
///
/// Maintains a logical clock and an EMA of the engine's prefetch hit rate,
/// which gates and scales how aggressively predictions are issued.
pub struct MlPrefetchAdapter {
/// Global (cross-file) prediction engine.
engine: MlPrefetchEngine,
/// Logical clock; incremented on every recorded read or write.
timestamp: u64,
/// Exponential moving average of the engine's prefetch hit rate.
hit_rate_ema: f32,
/// Aggressiveness factor derived from the EMA (0.5 / 1.0 / 1.5).
distance_multiplier: f32,
}
impl Default for MlPrefetchAdapter {
fn default() -> Self {
Self::new()
}
}
impl MlPrefetchAdapter {
    /// Creates an adapter with a fresh engine, a neutral 0.5 hit-rate EMA
    /// and a 1.0 prefetch-distance multiplier.
    pub fn new() -> Self {
        Self {
            engine: MlPrefetchEngine::new(),
            timestamp: 0,
            hit_rate_ema: 0.5,
            distance_multiplier: 1.0,
        }
    }

    /// Records a read, asks the ML engine(s) for predicted future accesses
    /// and enqueues prefetches for them.
    ///
    /// The global engine always observes the access so cross-file patterns
    /// are still learned; when `file_id` is given, the per-file engine's
    /// predictions take precedence over the global ones.
    ///
    /// Returns the number of prefetch requests newly enqueued.
    pub fn record_read(&mut self, dva: Dva, size: u64, file_id: Option<u64>) -> usize {
        if !PREFETCH_ENABLED.load(Ordering::Relaxed) {
            return 0;
        }
        self.timestamp += 1;
        let offset = dva.offset;
        // Feed the global engine unconditionally, even when its predictions
        // are superseded by a per-file engine below.
        let global_predictions = self
            .engine
            .record_and_predict(offset, size, self.timestamp, true);
        let predictions = match file_id {
            Some(fid) => FILE_ENGINES
                .lock()
                .entry(fid)
                .or_default()
                .record_and_predict(offset, size, self.timestamp, true),
            None => global_predictions,
        };
        self.issue_prefetches(predictions, dva, size)
    }

    /// Records a write so the engines learn the access pattern.
    ///
    /// Writes never trigger prefetches (the read flag passed to the engine
    /// is `false`), and, unlike reads, they are recorded even while
    /// prefetching is disabled.
    pub fn record_write(&mut self, dva: Dva, size: u64, file_id: Option<u64>) {
        self.timestamp += 1;
        let offset = dva.offset;
        let _ = self
            .engine
            .record_and_predict(offset, size, self.timestamp, false);
        if let Some(fid) = file_id {
            let _ = FILE_ENGINES
                .lock()
                .entry(fid)
                .or_default()
                .record_and_predict(offset, size, self.timestamp, false);
        }
    }

    /// Filters `predictions` against the pending queue and the ARC, enqueues
    /// the survivors and then drains the queue once.
    ///
    /// Returns how many new requests were enqueued this call.
    fn issue_prefetches(
        &mut self,
        predictions: Vec<(u64, u64)>,
        source_dva: Dva,
        _size: u64,
    ) -> usize {
        self.update_hit_rate();
        // After a warm-up period, stop prefetching entirely when it is not
        // paying off.
        if self.hit_rate_ema < MIN_HIT_RATE && self.timestamp > 100 {
            return 0;
        }
        // Scale aggressiveness with observed effectiveness. NOTE(review):
        // the multiplier is only recorded in stats here; nothing in this
        // file feeds it back into the engine's prediction distance yet.
        self.distance_multiplier = if self.hit_rate_ema > AGGRESSIVE_HIT_RATE {
            1.5
        } else if self.hit_rate_ema > MIN_HIT_RATE {
            1.0
        } else {
            0.5
        };
        // Take the queue lock only after the cheap gating above, keeping the
        // critical section as short as possible.
        let mut queue = PREFETCH_QUEUE.lock();
        let mut issued = 0;
        for (offset, pred_size) in predictions {
            if queue.len() >= MAX_PENDING_PREFETCHES {
                PREFETCH_QUEUE_FULL.fetch_add(1, Ordering::Relaxed);
                break;
            }
            let prefetch_dva = Dva {
                vdev: source_dva.vdev,
                offset,
            };
            // Skip duplicates already queued for this DVA.
            if queue.contains_key(&prefetch_dva) {
                continue;
            }
            // Skip blocks that are already cache-resident.
            if is_in_arc(&prefetch_dva) {
                PREFETCH_ALREADY_CACHED.fetch_add(1, Ordering::Relaxed);
                continue;
            }
            let priority = if self.hit_rate_ema > 0.7 {
                HIGH_PREFETCH_PRIORITY
            } else {
                DEFAULT_PREFETCH_PRIORITY
            };
            queue.insert(
                prefetch_dva,
                PrefetchRequest {
                    dva: prefetch_dva,
                    size: pred_size,
                    priority,
                    use_l2arc: should_use_l2arc(&prefetch_dva),
                    timestamp: self.timestamp,
                },
            );
            issued += 1;
        }
        self.process_prefetch_queue(&mut queue);
        issued
    }

    /// Folds the engine's current hit ratio into the exponential moving
    /// average (alpha = 0.1). No-op until the engine has seen at least one
    /// prefetch hit or miss.
    fn update_hit_rate(&mut self) {
        let stats = self.engine.get_stats();
        let total = stats.prefetch_hits + stats.prefetch_misses;
        if total > 0 {
            let current_rate = stats.prefetch_hits as f32 / total as f32;
            self.hit_rate_ema = 0.9 * self.hit_rate_ema + 0.1 * current_rate;
        }
    }

    /// Attempts every pending request once, removing those that completed
    /// (success, already cached, or I/O error). Requests that could not run
    /// (`QueueFull`/`Disabled`) stay queued for a later pass.
    fn process_prefetch_queue(&self, queue: &mut BTreeMap<Dva, PrefetchRequest>) {
        queue.retain(|_, request| match execute_prefetch(request) {
            PrefetchResult::Success => {
                PREFETCH_SUCCESS.fetch_add(1, Ordering::Relaxed);
                if request.use_l2arc {
                    PREFETCH_L2ARC_COUNT.fetch_add(1, Ordering::Relaxed);
                }
                false
            }
            PrefetchResult::AlreadyCached => {
                PREFETCH_ALREADY_CACHED.fetch_add(1, Ordering::Relaxed);
                false
            }
            PrefetchResult::IoError => {
                PREFETCH_IO_ERRORS.fetch_add(1, Ordering::Relaxed);
                false
            }
            PrefetchResult::QueueFull | PrefetchResult::Disabled => true,
        });
    }

    /// Current smoothed prefetch hit rate in `[0, 1]`.
    pub fn get_hit_rate(&self) -> f32 {
        self.hit_rate_ema
    }

    /// Snapshot combining the engine's statistics, the module-level
    /// counters, and the adapter's tuning state.
    pub fn get_stats(&self) -> AdapterStats {
        AdapterStats {
            ml_stats: self.engine.get_stats(),
            prefetch_success: PREFETCH_SUCCESS.load(Ordering::Relaxed),
            prefetch_already_cached: PREFETCH_ALREADY_CACHED.load(Ordering::Relaxed),
            prefetch_queue_full: PREFETCH_QUEUE_FULL.load(Ordering::Relaxed),
            prefetch_io_errors: PREFETCH_IO_ERRORS.load(Ordering::Relaxed),
            prefetch_l2arc: PREFETCH_L2ARC_COUNT.load(Ordering::Relaxed),
            hit_rate_ema: self.hit_rate_ema,
            distance_multiplier: self.distance_multiplier,
            pending_queue_size: PREFETCH_QUEUE.lock().len() as u64,
        }
    }

    /// Expires stale engine-side prefetch records and drops queued requests
    /// older than 1000 logical-clock ticks.
    pub fn expire_old(&mut self) {
        self.engine.expire_prefetches();
        let cutoff = self.timestamp.saturating_sub(1000);
        PREFETCH_QUEUE.lock().retain(|_, req| req.timestamp > cutoff);
    }
}
/// Combined statistics snapshot: engine-level ML stats plus the module-level
/// prefetch counters and the adapter's current tuning state.
#[derive(Debug, Clone)]
pub struct AdapterStats {
/// Statistics reported by the global ML engine.
pub ml_stats: PrefetchStats,
/// Prefetches that completed successfully (module-wide counter).
pub prefetch_success: u64,
/// Prefetches skipped because the block was already cached.
pub prefetch_already_cached: u64,
/// Predictions dropped because the pending queue was full.
pub prefetch_queue_full: u64,
/// Prefetches that failed with an I/O error.
pub prefetch_io_errors: u64,
/// Successful prefetches steered toward L2ARC.
pub prefetch_l2arc: u64,
/// Smoothed prefetch hit rate at snapshot time.
pub hit_rate_ema: f32,
/// Aggressiveness factor at snapshot time (0.5 / 1.0 / 1.5).
pub distance_multiplier: f32,
/// Number of requests waiting in the pending queue.
pub pending_queue_size: u64,
}
/// Returns `true` if `dva` is already resident in either the primary ARC or
/// its corresponding shard. The primary ARC lock is released before the
/// shard lock is taken, so at most one lock is held at a time.
fn is_in_arc(dva: &Dva) -> bool {
    let in_primary = ARC.lock().index.contains_key(dva);
    if in_primary {
        return true;
    }
    let shard_idx = dva.offset as usize % SHARDED_ARC.len();
    SHARDED_ARC[shard_idx].lock().index.contains_key(dva)
}
/// Heuristic: steer the prefetch toward L2ARC when the primary ARC is under
/// memory pressure (more than 80% of its byte budget in use). The DVA itself
/// is currently unused by the decision.
fn should_use_l2arc(_dva: &Dva) -> bool {
    let (used, budget) = {
        let arc = ARC.lock();
        // max(1) guards against division by zero on an unconfigured cache.
        (arc.current_size as f64, arc.max_bytes.max(1) as f64)
    };
    used / budget > 0.8
}
/// Runs a single queued prefetch request and reports the outcome.
///
/// NOTE(review): this currently only takes the target ARC shard lock and
/// reports success — the actual device read appears to be stubbed out here;
/// confirm where the real I/O is issued.
fn execute_prefetch(request: &PrefetchRequest) -> PrefetchResult {
    if !PREFETCH_ENABLED.load(Ordering::Relaxed) {
        PrefetchResult::Disabled
    } else if is_in_arc(&request.dva) {
        PrefetchResult::AlreadyCached
    } else {
        let shard = &SHARDED_ARC[request.dva.offset as usize % SHARDED_ARC.len()];
        let _guard = shard.lock();
        PrefetchResult::Success
    }
}
/// Module-level entry point: records a read on the global adapter and
/// returns the number of prefetches issued.
pub fn record_read(dva: Dva, size: u64, file_id: Option<u64>) -> usize {
    PREFETCH_ADAPTER.lock().record_read(dva, size, file_id)
}
/// Module-level entry point: records a write on the global adapter so the
/// engines learn the pattern (writes never trigger prefetches).
pub fn record_write(dva: Dva, size: u64, file_id: Option<u64>) {
    PREFETCH_ADAPTER.lock().record_write(dva, size, file_id);
}
/// Returns the global adapter's smoothed prefetch hit rate.
pub fn get_hit_rate() -> f32 {
    PREFETCH_ADAPTER.lock().get_hit_rate()
}
/// Returns a statistics snapshot from the global adapter.
pub fn get_stats() -> AdapterStats {
    PREFETCH_ADAPTER.lock().get_stats()
}
/// Turns the prefetcher on or off globally.
pub fn set_enabled(enabled: bool) {
    PREFETCH_ENABLED.store(enabled, Ordering::Relaxed);
}
/// Reports whether the prefetcher is currently enabled.
pub fn is_enabled() -> bool {
    PREFETCH_ENABLED.load(Ordering::Relaxed)
}
/// Expires stale engine records and old queued requests on the global adapter.
pub fn expire_old() {
    PREFETCH_ADAPTER.lock().expire_old();
}
/// Number of prefetch requests currently waiting in the queue.
pub fn pending_count() -> usize {
    let queue = PREFETCH_QUEUE.lock();
    queue.len()
}
/// Discards every queued prefetch request without executing it.
pub fn clear_pending() {
    let mut queue = PREFETCH_QUEUE.lock();
    queue.clear();
}
/// Drops the per-file prediction engine for `file_id`, if one exists
/// (e.g. when the file is closed or deleted).
pub fn clear_file_engine(file_id: u64) {
    let _ = FILE_ENGINES.lock().remove(&file_id);
}
/// Drops every per-file prediction engine (e.g. on unmount).
pub fn clear_all_file_engines() {
    let mut engines = FILE_ENGINES.lock();
    engines.clear();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Zeroes the module-level counters so assertions start from a known
    /// baseline. The counters are global and the test harness runs tests in
    /// parallel, so tests below avoid asserting exact global counts.
    fn reset_stats() {
        PREFETCH_SUCCESS.store(0, Ordering::Relaxed);
        PREFETCH_ALREADY_CACHED.store(0, Ordering::Relaxed);
        PREFETCH_QUEUE_FULL.store(0, Ordering::Relaxed);
        PREFETCH_IO_ERRORS.store(0, Ordering::Relaxed);
        PREFETCH_L2ARC_COUNT.store(0, Ordering::Relaxed);
    }

    #[test]
    fn test_adapter_creation() {
        let adapter = MlPrefetchAdapter::new();
        assert_eq!(adapter.timestamp, 0);
        assert!(adapter.hit_rate_ema > 0.0);
    }

    #[test]
    fn test_record_read_generates_prefetches() {
        reset_stats();
        // A private adapter instance keeps the engine-side stats isolated,
        // even though the prefetch queue itself is shared.
        let mut adapter = MlPrefetchAdapter::new();
        for i in 0..10 {
            let dva = Dva {
                vdev: 0,
                offset: i * 4096,
            };
            adapter.record_read(dva, 4096, None);
        }
        let stats = adapter.get_stats();
        assert!(stats.ml_stats.total_prefetches > 0);
    }

    #[test]
    fn test_write_no_prefetch() {
        reset_stats();
        let mut adapter = MlPrefetchAdapter::new();
        for i in 0..10 {
            let dva = Dva {
                vdev: 0,
                offset: i * 4096,
            };
            adapter.record_write(dva, 4096, None);
        }
        let stats = adapter.get_stats();
        assert_eq!(stats.ml_stats.total_prefetches, 0);
    }

    #[test]
    fn test_per_file_engines() {
        reset_stats();
        // FILE_ENGINES is shared across all tests and tests run in parallel,
        // so asserting on the map's total length is flaky (e.g.
        // test_clear_file_engine inserts id 999 concurrently). Use ids
        // reserved for this test and check membership instead.
        let file_ids = [201u64, 202, 203];
        for &file_id in &file_ids {
            for i in 0..5 {
                let dva = Dva {
                    vdev: 0,
                    offset: i * 4096,
                };
                record_read(dva, 4096, Some(file_id));
            }
        }
        let engines = FILE_ENGINES.lock();
        for &file_id in &file_ids {
            assert!(engines.contains_key(&file_id));
        }
    }

    #[test]
    fn test_enable_disable() {
        // NOTE: toggles a global flag that other tests depend on; the
        // disabled window is kept as short as possible.
        set_enabled(false);
        assert!(!is_enabled());
        set_enabled(true);
        assert!(is_enabled());
    }

    #[test]
    fn test_pending_queue() {
        // The queue is global, so a concurrent record_read test may enqueue
        // between the clear and the assert; this only verifies clear works
        // in the common (uncontended) case.
        clear_pending();
        assert_eq!(pending_count(), 0);
    }

    #[test]
    fn test_stats_structure() {
        // Smoke test: every public field of AdapterStats is reachable.
        let stats = get_stats();
        let _ = stats.ml_stats;
        let _ = stats.prefetch_success;
        let _ = stats.prefetch_already_cached;
        let _ = stats.prefetch_queue_full;
        let _ = stats.prefetch_io_errors;
        let _ = stats.prefetch_l2arc;
        let _ = stats.hit_rate_ema;
        let _ = stats.distance_multiplier;
        let _ = stats.pending_queue_size;
    }

    #[test]
    fn test_hit_rate_calculation() {
        reset_stats();
        let adapter = MlPrefetchAdapter::new();
        // A fresh adapter starts from the neutral 0.5 EMA seed.
        assert!(adapter.get_hit_rate() > 0.0);
        assert!(adapter.get_hit_rate() <= 1.0);
    }

    #[test]
    fn test_clear_file_engine() {
        // Id 999 is reserved for this test; no other test touches it.
        let dva = Dva { vdev: 0, offset: 0 };
        record_read(dva, 4096, Some(999));
        {
            let engines = FILE_ENGINES.lock();
            assert!(engines.contains_key(&999));
        }
        clear_file_engine(999);
        {
            let engines = FILE_ENGINES.lock();
            assert!(!engines.contains_key(&999));
        }
    }
}