use crate::PrefetchStrategy;
use crate::prefetch::neural::KeyConvert;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;
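/// A prefetch strategy driven by tabular Q-learning.
///
/// Each accessed key is treated as a state, and the agent learns which fixed
/// key offsets (its actions) best predict the next access. Predictions apply
/// the highest-valued offsets to the current key.
///
/// A minimal usage sketch, assuming a `KeyConvert` impl exists for the key
/// type (here `u64`; that impl is an assumption, not part of this module):
///
/// ```ignore
/// let mut prefetcher: RLPrefetch<u64> = RLPrefetch::new();
/// prefetcher.update_access_pattern(&100);
/// prefetcher.update_access_pattern(&101); // rewards the +1 offset
/// let predicted = prefetcher.predict_next(&101);
/// ```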
#[derive(Debug)]
pub struct RLPrefetch<K> {
    /// Learned Q-values, keyed by (state, action) pairs.
    q_table: HashMap<(u64, i64), f32>,
    /// The most recently accessed key, used as the previous state.
    current_state: Option<u64>,
    /// Candidate key offsets the agent can take as prefetch actions.
    actions: Vec<i64>,
    /// Sliding window of recent rewards, kept for diagnostics.
    recent_rewards: VecDeque<f32>,
    /// Number of accesses observed since construction or the last reset.
    episode_count: usize,
    /// Upper bound on the number of keys returned by `predict_next`.
    max_predictions: usize,
    _marker: PhantomData<K>,
}
impl<K> RLPrefetch<K> {
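    /// Creates a prefetcher with offset actions {-2, -1, 1, 2}, a 1000-sample
    /// reward window, and at most 3 predicted keys per access.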
pub fn new() -> Self {
Self {
q_table: HashMap::new(),
current_state: None,
actions: vec![-2, -1, 1, 2],
recent_rewards: VecDeque::with_capacity(1000),
episode_count: 0,
max_predictions: 3,
_marker: PhantomData,
}
}
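    /// Returns every action whose Q-value for `state` ties for the maximum
    /// (greedy selection; unseen state-action pairs default to a Q-value of 0.0).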
fn select_actions(&self, state: u64) -> Vec<i64> {
let mut best_actions = Vec::new();
let mut best_q = f32::MIN;
for &action in &self.actions {
let q_val = *self.q_table.get(&(state, action)).unwrap_or(&0.0);
if q_val > best_q {
best_q = q_val;
best_actions.clear();
best_actions.push(action);
            } else if (q_val - best_q).abs() < f32::EPSILON {
best_actions.push(action);
}
}
best_actions
}
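    /// Binary reward: 1.0 for an exact hit on the actual next key, 0.0 otherwise.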
fn compute_reward(&self, predicted_addr: u64, actual_addr: u64) -> f32 {
if predicted_addr == actual_addr {
1.0
} else {
0.0
}
}
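    /// Applies the standard one-step Q-learning update:
    /// Q(s, a) <- Q(s, a) + ALPHA * (reward + GAMMA * max_a' Q(s', a') - Q(s, a))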
fn update_q_value(&mut self, state: u64, action: i64, reward: f32, next_state: u64) {
        // Learning rate: how far each update moves toward the new estimate.
        const ALPHA: f32 = 0.1;
        // Discount factor: weight of estimated future reward.
        const GAMMA: f32 = 0.9;
        // Greedy estimate of the best Q-value reachable from the next state;
        // `actions` is never empty, so the fold always sees at least one value.
        let next_max_q = self.actions.iter()
            .map(|&a| *self.q_table.get(&(next_state, a)).unwrap_or(&0.0))
            .fold(f32::MIN, f32::max);
        let old_q = *self.q_table.get(&(state, action)).unwrap_or(&0.0);
        let new_q = old_q + ALPHA * (reward + GAMMA * next_max_q - old_q);
        self.q_table.insert((state, action), new_q);
}
}
impl<K> PrefetchStrategy<K> for RLPrefetch<K>
where
K: Copy + KeyConvert,
{
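    /// Treats the accessed key as the current state, selects the greedy action
    /// set, and returns up to `max_predictions` keys at those offsets.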
fn predict_next(&mut self, accessed_key: &K) -> Vec<K> {
let state = accessed_key.to_u64();
let actions = self.select_actions(state);
        actions.into_iter()
            .take(self.max_predictions)
            .map(|action| {
                // Apply the signed offset with wrapping u64 arithmetic, so keys
                // above i64::MAX are offset correctly rather than clamped by an
                // intermediate signed cast.
                let predicted_addr = state.wrapping_add(action as u64);
                K::from_u64(predicted_addr)
            })
            .collect()
}
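    /// Observes an access: every candidate action from the previous state is
    /// rewarded against the key that actually arrived, then the current state
    /// advances to this key.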
fn update_access_pattern(&mut self, key: &K) {
        let new_state = key.to_u64();
        if let Some(prev_state) = self.current_state {
            // Clone the action list so the loop can mutate `self` without
            // holding an immutable borrow of `self.actions`.
            let actions = self.actions.clone();
            for action in actions {
                // Score every candidate action against the access that actually
                // happened, so all actions receive feedback on each step.
                let predicted_addr = prev_state.wrapping_add(action as u64);
                let reward = self.compute_reward(predicted_addr, new_state);
                self.update_q_value(prev_state, action, reward, new_state);
                self.recent_rewards.push_back(reward);
                // Keep only a sliding window of the most recent rewards.
                if self.recent_rewards.len() > 1000 {
                    self.recent_rewards.pop_front();
                }
            }
        }
        self.current_state = Some(new_state);
        self.episode_count += 1;
}
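    /// Discards all learned Q-values and access history.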
fn reset(&mut self) {
self.q_table.clear();
self.current_state = None;
self.recent_rewards.clear();
self.episode_count = 0;
}
}
impl<K> Default for RLPrefetch<K> {
fn default() -> Self {
Self::new()
}
}