fulgurance 0.4.1

A blazing-fast, adaptive prefetching and caching library for Rust
Documentation
//! Reinforcement Learning (RL) Prefetch Strategy

use crate::PrefetchStrategy;
use crate::prefetch::neural::KeyConvert;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;

/// Number of per-action rewards retained in `RLPrefetch::recent_rewards`;
/// once exceeded, the oldest reward is discarded.
const REWARD_HISTORY_CAPACITY: usize = 1000;

/// Prefetch strategy that learns, per accessed key, which small key offsets
/// tend to be accessed next, using tabular Q-learning.
///
/// The state is the `u64` form of the last accessed key and each action is a
/// signed offset from it (`-2, -1, +1, +2`). After every access, each action
/// from the previous key is scored: reward `1.0` if `previous + offset` equals
/// the key just accessed, `0.0` otherwise. Offsets that would leave the `u64`
/// key space are never predicted or rewarded.
///
/// NOTE(review): the Q-table is keyed by absolute key, so it gains up to four
/// entries for every distinct key observed and is only cleared by `reset`;
/// long-running use over a large key space grows memory without bound.
#[derive(Debug)]
pub struct RLPrefetch<K> {
    /// Learned action values keyed by `(state, action)`; missing entries count as `0.0`.
    q_table: HashMap<(u64, i64), f32>,
    /// Numeric form of the most recently observed key, if any.
    current_state: Option<u64>,
    /// Candidate signed offsets applied to a key to form predictions.
    actions: Vec<i64>,
    /// Sliding window of the most recent per-action rewards
    /// (not consumed anywhere in this file).
    recent_rewards: VecDeque<f32>,
    /// Number of accesses observed since construction or the last `reset`.
    episode_count: usize,
    /// Upper bound on the number of keys returned by `predict_next`.
    max_predictions: usize,
    _marker: PhantomData<K>,
}

impl<K> RLPrefetch<K> {
    /// Creates an untrained strategy with offsets `[-2, -1, 1, 2]` and at most
    /// three predictions per call.
    pub fn new() -> Self {
        Self {
            q_table: HashMap::new(),
            current_state: None,
            actions: vec![-2, -1, 1, 2],
            recent_rewards: VecDeque::with_capacity(REWARD_HISTORY_CAPACITY),
            episode_count: 0,
            max_predictions: 3,
            _marker: PhantomData,
        }
    }

    /// Applies the signed `offset` to `key`, returning `None` if the result
    /// would fall outside `0..=u64::MAX`.
    ///
    /// Works purely in `u64` arithmetic so keys on either side of `i64::MAX`
    /// are shifted correctly and results never wrap around.
    fn offset_key(key: u64, offset: i64) -> Option<u64> {
        if offset >= 0 {
            key.checked_add(offset as u64)
        } else {
            key.checked_sub(offset.unsigned_abs())
        }
    }

    /// Returns every in-range action whose Q-value for `state` ties for the
    /// maximum, in the order they appear in `self.actions`.
    ///
    /// Values within `f32::EPSILON` of the running best count as ties.
    fn select_actions(&self, state: u64) -> Vec<i64> {
        let mut best_actions = Vec::new();
        let mut best_q = f32::MIN;

        for &action in &self.actions {
            // An offset that leaves the key space cannot name a key to prefetch.
            if Self::offset_key(state, action).is_none() {
                continue;
            }
            let q_val = self.q_table.get(&(state, action)).copied().unwrap_or(0.0);
            if q_val > best_q {
                best_q = q_val;
                best_actions.clear();
                best_actions.push(action);
            } else if (q_val - best_q).abs() < f32::EPSILON {
                best_actions.push(action);
            }
        }

        best_actions
    }

    /// Reward for a single prediction: `1.0` on an exact hit, `0.0` otherwise.
    fn compute_reward(&self, predicted_addr: u64, actual_addr: u64) -> f32 {
        if predicted_addr == actual_addr {
            1.0
        } else {
            0.0
        }
    }

    /// Applies one Q-learning step to `q_table[(state, action)]`:
    /// `Q <- Q + ALPHA * (reward + GAMMA * max_a' Q(next_state, a') - Q)`.
    ///
    /// Takes the table and the action set as separate borrows (instead of
    /// `&mut self`) so a caller can iterate `self.actions` while updating
    /// `self.q_table` without cloning the action list.
    fn update_q_value(
        q_table: &mut HashMap<(u64, i64), f32>,
        actions: &[i64],
        state: u64,
        action: i64,
        reward: f32,
        next_state: u64,
    ) {
        const ALPHA: f32 = 0.1;
        const GAMMA: f32 = 0.9;

        // Unseen pairs default to 0.0; an empty action set contributes no
        // future value rather than `GAMMA * f32::MIN`.
        let next_max_q = actions
            .iter()
            .map(|&a| q_table.get(&(next_state, a)).copied().unwrap_or(0.0))
            .reduce(f32::max)
            .unwrap_or(0.0);

        let q = q_table.entry((state, action)).or_insert(0.0);
        *q += ALPHA * (reward + GAMMA * next_max_q - *q);
    }
}

impl<K> PrefetchStrategy<K> for RLPrefetch<K>
where
    K: Copy + KeyConvert,
{
    /// Predicts up to `max_predictions` keys to prefetch after `accessed_key`,
    /// using the highest-valued offsets learned for that key.
    ///
    /// Offsets that would overflow or underflow the `u64` key space are
    /// skipped, so fewer keys may be returned near `0` and `u64::MAX`.
    fn predict_next(&mut self, accessed_key: &K) -> Vec<K> {
        let state = accessed_key.to_u64();
        self.select_actions(state)
            .into_iter()
            .filter_map(|action| Self::offset_key(state, action))
            .take(self.max_predictions)
            .map(K::from_u64)
            .collect()
    }

    /// Records an access to `key`, rewarding each action from the previously
    /// accessed key according to whether it would have predicted `key`.
    fn update_access_pattern(&mut self, key: &K) {
        let new_state = key.to_u64();

        if let Some(prev_state) = self.current_state {
            // Borrow `actions` and `q_table` as disjoint fields so no per-access
            // clone is needed. Updates remain sequential: when prev_state ==
            // new_state, later actions see the Q-values written for earlier ones.
            for &action in &self.actions {
                // Out-of-range offsets can never hit; leave their Q-value at its
                // default so they are never preferred over a valid offset.
                let predicted_addr = match Self::offset_key(prev_state, action) {
                    Some(addr) => addr,
                    None => continue,
                };
                let reward = self.compute_reward(predicted_addr, new_state);
                Self::update_q_value(
                    &mut self.q_table,
                    &self.actions,
                    prev_state,
                    action,
                    reward,
                    new_state,
                );
                self.recent_rewards.push_back(reward);
                if self.recent_rewards.len() > REWARD_HISTORY_CAPACITY {
                    self.recent_rewards.pop_front();
                }
            }
        }

        self.current_state = Some(new_state);
        self.episode_count += 1;
    }

    /// Forgets all learned values, history and the current state; the action
    /// set and prediction limit are kept.
    fn reset(&mut self) {
        self.q_table.clear();
        self.current_state = None;
        self.recent_rewards.clear();
        self.episode_count = 0;
    }
}

/// Equivalent to [`RLPrefetch::new`]; available for any key type.
impl<K> Default for RLPrefetch<K> {
    fn default() -> Self {
        Self::new()
    }
}