// fulgurance 0.4.1
//
// A blazing-fast, adaptive prefetching and caching library for Rust.
// Documentation
//! LSTM-based Prefetch Strategy
//!
//! Simplified Long Short-Term Memory network for temporal dependencies.

use crate::PrefetchStrategy;
use crate::prefetch::neural::KeyConvert;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;

/// Prefetch strategy built around a single, fixed-weight LSTM cell.
///
/// Every observed address is assigned a dense vocabulary id. One LSTM step is
/// run on an id to refresh the hidden state, and all known addresses are then
/// ranked against that hidden state. The weights are initialised
/// deterministically and nothing in this block ever updates them.
///
/// `K` is the cache key type; it is carried only as a `PhantomData` marker.
#[derive(Debug, Clone)]
pub struct LSTMPrefetch<K> {
    /// Cell (long-term) state, `hidden_size` entries.
    cell_state: Vec<f32>,
    /// Hidden (short-term) state, `hidden_size` entries.
    hidden_state: Vec<f32>,
    /// Forget-gate weights: `2 * hidden_size + 1` rows by `hidden_size`
    /// columns. Row `r` multiplies element `r` of the concatenated
    /// `[input, hidden_state]` vector. The three matrices below share this
    /// layout.
    forget_weights: Vec<Vec<f32>>,
    /// Input-gate weights.
    input_weights: Vec<Vec<f32>>,
    /// Output-gate weights.
    output_weights: Vec<Vec<f32>>,
    /// Candidate-value weights.
    candidate_weights: Vec<Vec<f32>>,
    /// Width of the cell and hidden state vectors.
    hidden_size: usize,
    /// Nominal vocabulary capacity. Stored but never read, so the vocabulary
    /// is effectively unbounded.
    #[allow(dead_code)]
    vocab_size: usize,
    /// Learning rate. Stored but never read: no training step exists here.
    #[allow(dead_code)]
    learning_rate: f32,
    /// Recently accessed addresses, oldest first.
    access_history: VecDeque<u64>,
    /// Address -> vocabulary id.
    address_vocab: HashMap<u64, usize>,
    /// Vocabulary id -> address (inverse of `address_vocab`).
    reverse_vocab: HashMap<usize, u64>,
    /// Id that will be handed to the next previously unseen address.
    next_vocab_id: usize,
    /// Upper bound on the number of addresses returned by a prediction.
    max_predictions: usize,
    _marker: PhantomData<K>,
}

impl<K> LSTMPrefetch<K> {
    /// Creates a predictor with 64 hidden units, a learning rate of 0.01 and
    /// at most 3 predictions per call.
    pub fn new() -> Self {
        Self::with_config(64, 0.01, 3)
    }

    /// Creates a predictor with explicit settings.
    ///
    /// * `hidden_size` - width of the hidden and cell state vectors.
    /// * `learning_rate` - stored only; nothing in this block trains weights.
    /// * `max_predictions` - cap on the number of ranked addresses returned.
    pub fn with_config(hidden_size: usize, learning_rate: f32, max_predictions: usize) -> Self {
        let vocab_size = 10000;
        // The gates consume `[input (hidden_size + 1) | hidden_state (hidden_size)]`,
        // so each matrix needs one row per element of that concatenation.
        // Allocating only `hidden_size + 1` rows would make the `zip` in the
        // forward pass drop the whole hidden-state half of the vector.
        let rows = 2 * hidden_size + 1;
        Self {
            cell_state: vec![0.0; hidden_size],
            hidden_state: vec![0.0; hidden_size],
            forget_weights: Self::init_weights(rows, hidden_size),
            input_weights: Self::init_weights(rows, hidden_size),
            output_weights: Self::init_weights(rows, hidden_size),
            candidate_weights: Self::init_weights(rows, hidden_size),
            hidden_size,
            vocab_size,
            learning_rate,
            access_history: VecDeque::with_capacity(100),
            address_vocab: HashMap::new(),
            reverse_vocab: HashMap::new(),
            next_vocab_id: 0,
            max_predictions,
            _marker: PhantomData,
        }
    }

    /// Builds a `rows` x `cols` matrix of deterministic pseudo-weights.
    ///
    /// The entry at `(i, j)` is `((i * cols + j) * 0.01) mod 2 - 1`, a
    /// sawtooth over `[-1, 1)` that depends only on the flattened position.
    fn init_weights(rows: usize, cols: usize) -> Vec<Vec<f32>> {
        (0..rows)
            .map(|i| {
                (0..cols)
                    .map(|j| ((i * cols + j) as f32 * 0.01) % 2.0 - 1.0)
                    .collect()
            })
            .collect()
    }

    /// Logistic sigmoid with the argument limited to `[-10, 10]`.
    fn sigmoid(x: f32) -> f32 {
        // `max`/`min` rather than `clamp`: `f32::max` returns the non-NaN
        // operand, so a NaN input collapses to -10 instead of propagating.
        let limited = x.max(-10.0).min(10.0);
        1.0 / (1.0 + (-limited).exp())
    }

    /// Hyperbolic tangent with the argument limited to `[-10, 10]`.
    fn tanh(x: f32) -> f32 {
        // Same NaN-absorbing `max`/`min` chain as `sigmoid`.
        x.max(-10.0).min(10.0).tanh()
    }

    /// Returns the vocabulary id of `address`, assigning the next free id
    /// (and recording the reverse mapping) the first time it is seen.
    fn get_or_create_vocab_id(&mut self, address: u64) -> usize {
        use std::collections::hash_map::Entry;

        match self.address_vocab.entry(address) {
            Entry::Occupied(slot) => *slot.get(),
            Entry::Vacant(slot) => {
                let id = self.next_vocab_id;
                slot.insert(id);
                self.reverse_vocab.insert(id, address);
                self.next_vocab_id += 1;
                id
            }
        }
    }

    /// Computes one gate: `activation(combined . W[.., j])` for every output
    /// unit `j` in `0..width`.
    ///
    /// Rows of `weights` beyond `combined.len()` and columns beyond `width`
    /// are ignored; columns missing from a row contribute zero.
    fn gate(
        combined: &[f32],
        weights: &[Vec<f32>],
        width: usize,
        activation: fn(f32) -> f32,
    ) -> Vec<f32> {
        let mut sums = vec![0.0_f32; width];
        for (&x, row) in combined.iter().zip(weights) {
            // The one-hot input is almost entirely zeros and the weights are
            // finite, so a zero element cannot change any sum.
            if x == 0.0 {
                continue;
            }
            for (acc, &w) in sums.iter_mut().zip(row) {
                *acc += x * w;
            }
        }
        sums.into_iter().map(activation).collect()
    }

    /// Runs one LSTM step for `input_id` and returns a copy of the new hidden
    /// state.
    ///
    /// The input is one-hot encoded into `hidden_size + 1` slots; ids that are
    /// `>= hidden_size` produce an all-zero input. The previous hidden state
    /// is appended to the input before the gates are evaluated, and both
    /// `cell_state` and `hidden_state` are updated in place.
    fn lstm_forward(&mut self, input_id: usize) -> Vec<f32> {
        let width = self.hidden_size;

        // [one-hot input | previous hidden state]
        let mut combined = vec![0.0_f32; width + 1];
        if input_id < width {
            combined[input_id] = 1.0;
        }
        combined.extend_from_slice(&self.hidden_state);

        // All gates read the same `combined`, so they can be evaluated
        // up front, before any state is overwritten.
        let forget_gate = Self::gate(&combined, &self.forget_weights, width, Self::sigmoid);
        let input_gate = Self::gate(&combined, &self.input_weights, width, Self::sigmoid);
        let candidates = Self::gate(&combined, &self.candidate_weights, width, Self::tanh);
        let output_gate = Self::gate(&combined, &self.output_weights, width, Self::sigmoid);

        // c_t = f * c_{t-1} + i * g
        for (((cell, &f), &i), &g) in self
            .cell_state
            .iter_mut()
            .zip(&forget_gate)
            .zip(&input_gate)
            .zip(&candidates)
        {
            *cell = f * *cell + i * g;
        }

        // h_t = o * tanh(c_t)
        for ((hidden, &o), &cell) in self
            .hidden_state
            .iter_mut()
            .zip(&output_gate)
            .zip(&self.cell_state)
        {
            *hidden = o * Self::tanh(cell);
        }

        self.hidden_state.clone()
    }

    /// Ranks every known address against the current hidden state and returns
    /// at most `max_predictions` of them, best first.
    ///
    /// The score of an address is `|sum_i h[i] * (vocab_id * 0.01 + i)|`.
    /// Equal scores are ordered by ascending address so the result does not
    /// depend on `HashMap` iteration order.
    fn predict_from_hidden(&self) -> Vec<u64> {
        let mut ranked: Vec<(u64, f32)> = self
            .reverse_vocab
            .iter()
            .map(|(&vocab_id, &address)| {
                let score: f32 = self
                    .hidden_state
                    .iter()
                    .enumerate()
                    .map(|(i, &h)| h * (vocab_id as f32 * 0.01 + i as f32))
                    .sum();
                (address, score.abs())
            })
            .collect();

        // `total_cmp` is a total order on floats, so there is no panic path.
        // Addresses are unique, which makes the combined key total and lets
        // an unstable sort stay deterministic.
        ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        ranked.truncate(self.max_predictions);
        ranked.into_iter().map(|(address, _)| address).collect()
    }
}

impl<K> PrefetchStrategy<K> for LSTMPrefetch<K>
where
    K: Copy + KeyConvert,
{
    /// Predicts the keys likely to be accessed next.
    ///
    /// With no recorded history the prediction falls back to the sequential
    /// successors `accessed_key + 1` and `accessed_key + 2`, capped at
    /// `max_predictions`. Otherwise the most recently recorded address is fed
    /// through one LSTM step and the known addresses are ranked against the
    /// resulting hidden state.
    fn predict_next(&mut self, accessed_key: &K) -> Vec<K> {
        let last_addr = match self.access_history.back().copied() {
            Some(addr) => addr,
            None => {
                let addr = accessed_key.to_u64();
                // `checked_add` drops successors that would pass `u64::MAX`
                // instead of panicking (debug) or wrapping around (release).
                return (1..=2u64)
                    .filter_map(|step| addr.checked_add(step))
                    .map(|next| K::from_u64(next))
                    .take(self.max_predictions)
                    .collect();
            }
        };

        let vocab_id = self.get_or_create_vocab_id(last_addr);
        // Only the side effect on the hidden state is needed here.
        self.lstm_forward(vocab_id);

        // `predict_from_hidden` already limits the result to
        // `max_predictions`, so no further `take` is required.
        self.predict_from_hidden()
            .into_iter()
            .map(|addr| K::from_u64(addr))
            .collect()
    }

    /// Records `key` as the most recent access, keeping at most the last 100
    /// addresses.
    fn update_access_pattern(&mut self, key: &K) {
        // Matches the capacity the history deque is created with.
        const HISTORY_LIMIT: usize = 100;

        self.access_history.push_back(key.to_u64());
        if self.access_history.len() > HISTORY_LIMIT {
            self.access_history.pop_front();
        }
    }

    /// Clears the recurrent state, the access history and the vocabulary.
    ///
    /// The gate weights and the configuration are left untouched.
    fn reset(&mut self) {
        self.cell_state.fill(0.0);
        self.hidden_state.fill(0.0);
        self.access_history.clear();
        self.address_vocab.clear();
        self.reverse_vocab.clear();
        self.next_vocab_id = 0;
    }
}

impl<K> Default for LSTMPrefetch<K>
where
    K: Copy + KeyConvert,
{
    fn default() -> Self {
        Self::new()
    }
}