use crate::PrefetchStrategy;
use crate::prefetch::neural::KeyConvert;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;
/// Prefetch predictor built around a single LSTM cell.
///
/// Every observed address is assigned a dense vocabulary id. One-hot encoded ids are fed
/// through the cell, and the resulting hidden state is used to score every known address.
/// Weights are initialised deterministically by [`LSTMPrefetch::init_weights`]; nothing in
/// this file updates them.
#[derive(Debug, Clone)]
pub struct LSTMPrefetch<K> {
    /// LSTM cell state `c_t`, one entry per hidden unit.
    cell_state: Vec<f32>,
    /// LSTM hidden state `h_t`, one entry per hidden unit.
    hidden_state: Vec<f32>,
    /// Gate weight matrices, `(2 * hidden_size + 1) x hidden_size`.
    ///
    /// Row `r` multiplies element `r` of the concatenated `[x_t, h_{t-1}]` input;
    /// column `i` belongs to hidden unit `i`.
    forget_weights: Vec<Vec<f32>>,
    input_weights: Vec<Vec<f32>>,
    output_weights: Vec<Vec<f32>>,
    candidate_weights: Vec<Vec<f32>>,
    /// Number of LSTM hidden units.
    hidden_size: usize,
    // Stored by `with_config` but never read.
    // NOTE(review): the vocabulary is not capped at this size — confirm intent.
    #[allow(dead_code)]
    vocab_size: usize,
    // Stored by `with_config` but never read: no training step exists yet.
    #[allow(dead_code)]
    learning_rate: f32,
    /// Most recent accessed addresses (the trait impl keeps at most 100).
    access_history: VecDeque<u64>,
    /// Address -> vocabulary id.
    address_vocab: HashMap<u64, usize>,
    /// Vocabulary id -> address (inverse of `address_vocab`).
    reverse_vocab: HashMap<usize, u64>,
    /// Id that the next previously-unseen address will receive.
    next_vocab_id: usize,
    /// Upper bound on the number of addresses returned per prediction.
    max_predictions: usize,
    _marker: PhantomData<K>,
}

impl<K> LSTMPrefetch<K> {
    /// Creates a predictor with 64 hidden units, learning rate 0.01 and up to 3 predictions.
    pub fn new() -> Self {
        Self::with_config(64, 0.01, 3)
    }

    /// Creates a predictor with the given hidden size, learning rate and prediction cap.
    ///
    /// The vocabulary size is fixed at 10000, but it is only stored, not enforced.
    pub fn with_config(hidden_size: usize, learning_rate: f32, max_predictions: usize) -> Self {
        let vocab_size = 10000;
        // Gate input is the one-hot x_t (hidden_size + 1 slots) concatenated with
        // h_{t-1} (hidden_size slots). Every weight matrix needs one row per element,
        // otherwise the recurrent part is silently dropped.
        let gate_inputs = 2 * hidden_size + 1;
        Self {
            cell_state: vec![0.0; hidden_size],
            hidden_state: vec![0.0; hidden_size],
            forget_weights: Self::init_weights(gate_inputs, hidden_size),
            input_weights: Self::init_weights(gate_inputs, hidden_size),
            output_weights: Self::init_weights(gate_inputs, hidden_size),
            candidate_weights: Self::init_weights(gate_inputs, hidden_size),
            hidden_size,
            vocab_size,
            learning_rate,
            access_history: VecDeque::with_capacity(100),
            address_vocab: HashMap::new(),
            reverse_vocab: HashMap::new(),
            next_vocab_id: 0,
            max_predictions,
            _marker: PhantomData,
        }
    }

    /// Builds a `rows x cols` matrix of deterministic pseudo-weights in `[-1, 1)`.
    ///
    /// Element `(i, j)` is `(0.01 * (i * cols + j)) mod 2 - 1`, a sawtooth over the
    /// flattened index.
    fn init_weights(rows: usize, cols: usize) -> Vec<Vec<f32>> {
        (0..rows)
            .map(|i| {
                (0..cols)
                    .map(|j| ((i * cols + j) as f32 * 0.01) % 2.0 - 1.0)
                    .collect()
            })
            .collect()
    }

    /// Logistic sigmoid with the input clamped to `[-10, 10]` to keep `exp` well-behaved.
    fn sigmoid(x: f32) -> f32 {
        // Method calls bind tighter than unary minus: this is exp(-(clamped x)).
        1.0 / (1.0 + (-x.max(-10.0).min(10.0)).exp())
    }

    /// Hyperbolic tangent with the input clamped to `[-10, 10]`.
    fn tanh(x: f32) -> f32 {
        x.max(-10.0).min(10.0).tanh()
    }

    /// Returns the vocabulary id for `address`, assigning the next free id if it is new.
    fn get_or_create_vocab_id(&mut self, address: u64) -> usize {
        if let Some(&id) = self.address_vocab.get(&address) {
            id
        } else {
            let id = self.next_vocab_id;
            self.address_vocab.insert(address, id);
            self.reverse_vocab.insert(id, address);
            self.next_vocab_id += 1;
            id
        }
    }

    /// Computes one gate: `activation(sum_r combined[r] * weights[r][i])` for each hidden unit `i`.
    fn gate(
        combined: &[f32],
        weights: &[Vec<f32>],
        hidden_size: usize,
        activation: fn(f32) -> f32,
    ) -> Vec<f32> {
        (0..hidden_size)
            .map(|i| {
                let sum: f32 = combined
                    .iter()
                    .zip(weights)
                    .map(|(&x, row)| x * row.get(i).copied().unwrap_or(0.0))
                    .sum();
                activation(sum)
            })
            .collect()
    }

    /// Advances the LSTM by one step with `input_id` as the input token.
    ///
    /// Updates `cell_state` and `hidden_state` in place and returns a copy of the new
    /// hidden state.
    fn lstm_forward(&mut self, input_id: usize) -> Vec<f32> {
        let h = self.hidden_size;

        // One-hot x_t. NOTE(review): ids >= hidden_size encode as all-zero and the
        // trailing slot (index hidden_size) is never set. Only the first hidden_size
        // distinct addresses are distinguishable to the model — confirm intent.
        let mut combined = vec![0.0; h + 1];
        if input_id < h {
            combined[input_id] = 1.0;
        }
        // Standard LSTM gate input: [x_t, h_{t-1}].
        combined.extend_from_slice(&self.hidden_state);

        let forget_gate = Self::gate(&combined, &self.forget_weights, h, Self::sigmoid);
        let input_gate = Self::gate(&combined, &self.input_weights, h, Self::sigmoid);
        let candidates = Self::gate(&combined, &self.candidate_weights, h, Self::tanh);
        let output_gate = Self::gate(&combined, &self.output_weights, h, Self::sigmoid);

        // c_t = f * c_{t-1} + i * g
        for (((cell, forget), input), candidate) in self
            .cell_state
            .iter_mut()
            .zip(&forget_gate)
            .zip(&input_gate)
            .zip(&candidates)
        {
            *cell = forget * *cell + input * candidate;
        }
        // h_t = o * tanh(c_t)
        for ((hidden, output), &cell) in self
            .hidden_state
            .iter_mut()
            .zip(&output_gate)
            .zip(&self.cell_state)
        {
            *hidden = output * Self::tanh(cell);
        }
        self.hidden_state.clone()
    }

    /// Scores every known address against the current hidden state.
    ///
    /// The score is `|sum_i h_i * (0.01 * vocab_id + i)|`, a fixed, untrained projection.
    /// Returns up to `max_predictions` addresses, highest score first.
    fn predict_from_hidden(&self) -> Vec<u64> {
        let mut scored: Vec<(usize, u64, f32)> = self
            .reverse_vocab
            .iter()
            .map(|(&vocab_id, &address)| {
                let score: f32 = self
                    .hidden_state
                    .iter()
                    .enumerate()
                    .map(|(i, &h)| h * (vocab_id as f32 * 0.01 + i as f32))
                    .sum();
                (vocab_id, address, score.abs())
            })
            .collect();
        // Highest score first. Ties are broken by vocabulary id so the result does not
        // depend on HashMap iteration order. `total_cmp` has no panic path.
        scored.sort_by(|a, b| b.2.total_cmp(&a.2).then(a.0.cmp(&b.0)));
        scored.truncate(self.max_predictions);
        scored.into_iter().map(|(_, address, _)| address).collect()
    }
}
impl<K> PrefetchStrategy<K> for LSTMPrefetch<K>
where
    K: Copy + KeyConvert,
{
    /// Predicts keys likely to be accessed soon.
    ///
    /// **Fallback (no recorded history):** returns the next sequential addresses after
    /// `accessed_key`. That is at most two, and never more than `max_predictions`.
    ///
    /// **Otherwise:** feeds the most recently recorded address through the LSTM. It then
    /// returns up to `max_predictions` of the highest-scoring known addresses.
    ///
    /// NOTE(review): with history present, the LSTM input is the back of `access_history`,
    /// not `accessed_key`. The two coincide only if the caller records the access via
    /// `update_access_pattern` before calling this — confirm the expected call order.
    fn predict_next(&mut self, accessed_key: &K) -> Vec<K> {
        let Some(&last_addr) = self.access_history.back() else {
            let addr = accessed_key.to_u64();
            // `wrapping_add` avoids the debug-build overflow panic at `u64::MAX`.
            return (1..=2u64)
                .map(|offset| K::from_u64(addr.wrapping_add(offset)))
                .take(self.max_predictions)
                .collect();
        };
        let vocab_id = self.get_or_create_vocab_id(last_addr);
        self.lstm_forward(vocab_id);
        self.predict_from_hidden()
            .into_iter()
            .take(self.max_predictions)
            .map(K::from_u64)
            .collect()
    }

    /// Records `key` as the newest access, keeping only the 100 most recent addresses.
    fn update_access_pattern(&mut self, key: &K) {
        let addr = key.to_u64();
        self.access_history.push_back(addr);
        if self.access_history.len() > 100 {
            self.access_history.pop_front();
        }
    }

    /// Clears the LSTM state, access history and vocabulary.
    ///
    /// Weights and configuration are kept.
    fn reset(&mut self) {
        self.cell_state.fill(0.0);
        self.hidden_state.fill(0.0);
        self.access_history.clear();
        self.address_vocab.clear();
        self.reverse_vocab.clear();
        self.next_vocab_id = 0;
    }
}
impl<K> Default for LSTMPrefetch<K>
where
K: Copy + KeyConvert,
{
fn default() -> Self {
Self::new()
}
}