use crate::PrefetchStrategy;
use std::collections::VecDeque;
use std::marker::PhantomData;
/// Two-way conversion between a cache key and a flat `u64` address space.
///
/// The provided integer implementations are plain `as` reinterpretations:
/// a negative signed key sign-extends to a large `u64` and round-trips
/// back to the same value; values wider than the target type truncate.
pub trait KeyConvert {
    /// Reinterprets the key as a `u64` address.
    fn to_u64(&self) -> u64;
    /// Rebuilds a key from a `u64` address.
    fn from_u64(addr: u64) -> Self;
}

impl KeyConvert for i32 {
    fn to_u64(&self) -> u64 {
        *self as u64
    }

    fn from_u64(addr: u64) -> Self {
        addr as i32
    }
}

impl KeyConvert for i64 {
    fn to_u64(&self) -> u64 {
        *self as u64
    }

    fn from_u64(addr: u64) -> Self {
        addr as i64
    }
}

impl KeyConvert for usize {
    fn to_u64(&self) -> u64 {
        *self as u64
    }

    fn from_u64(addr: u64) -> Self {
        addr as usize
    }
}
/// A small two-layer neural network (sequence positions -> hidden units ->
/// address buckets) that learns an access pattern and predicts addresses to
/// prefetch.
///
/// Addresses are quantized into `vocab_size` buckets spanning the observed
/// `[min_address, max_address]` range; predictions are bucket indices mapped
/// back to representative addresses.
#[derive(Debug, Clone)]
pub struct NeuralPrefetch<K> {
// First-layer weights, indexed [sequence position][hidden unit].
input_weights: Vec<Vec<f32>>,
// Second-layer weights, indexed [hidden unit][vocab bucket].
hidden_weights: Vec<Vec<f32>>,
// Hidden-layer bias, one entry per hidden unit.
input_bias: Vec<f32>,
// Output-layer bias, one entry per vocab bucket.
hidden_bias: Vec<f32>,
// Sliding window of the most recent accesses (capped at `sequence_length`).
access_history: VecDeque<u64>,
// Number of past accesses fed to the network per prediction.
sequence_length: usize,
// Hidden-layer width.
hidden_size: usize,
// Number of buckets the address range is quantized into.
vocab_size: usize,
// SGD step size used during training updates.
learning_rate: f32,
// Observed address range; min > max (u64::MAX / 0) means "nothing seen yet".
min_address: u64,
max_address: u64,
// Minimum softmax probability for a bucket to be emitted as a prediction.
confidence_threshold: f32,
// Cap on the number of predictions returned per query.
max_predictions: usize,
// Count of completed training updates.
training_steps: usize,
// K is only used in the strategy interface, not stored.
_marker: PhantomData<K>,
}
impl<K> NeuralPrefetch<K> {
    /// Creates a prefetcher with default hyperparameters: an 8-access
    /// window, 32 hidden units, 1000 address buckets, learning rate 0.01,
    /// confidence threshold 0.5, and at most 3 predictions per query.
    pub fn new() -> Self {
        Self::with_config(8, 32, 1000, 0.01, 0.5, 3)
    }

    /// Creates a prefetcher with explicit hyperparameters.
    ///
    /// * `sequence_length` - number of past accesses fed to the network
    /// * `hidden_size` - hidden-layer width
    /// * `vocab_size` - number of buckets the address range is quantized into
    /// * `learning_rate` - SGD step size
    /// * `confidence_threshold` - minimum softmax probability to emit a prediction
    /// * `max_predictions` - cap on predictions returned per query
    pub fn with_config(
        sequence_length: usize,
        hidden_size: usize,
        vocab_size: usize,
        learning_rate: f32,
        confidence_threshold: f32,
        max_predictions: usize,
    ) -> Self {
        let input_weights = Self::init_weights(sequence_length, hidden_size);
        let hidden_weights = Self::init_weights(hidden_size, vocab_size);
        Self {
            input_weights,
            hidden_weights,
            input_bias: vec![0.0; hidden_size],
            hidden_bias: vec![0.0; vocab_size],
            access_history: VecDeque::with_capacity(sequence_length),
            sequence_length,
            hidden_size,
            vocab_size,
            learning_rate,
            // min > max is the sentinel for "no addresses observed yet".
            min_address: u64::MAX,
            max_address: 0,
            confidence_threshold,
            max_predictions,
            training_steps: 0,
            _marker: PhantomData,
        }
    }

    /// Builds a `rows x cols` matrix with uniform random entries in [-0.05, 0.05).
    fn init_weights(rows: usize, cols: usize) -> Vec<Vec<f32>> {
        (0..rows)
            .map(|_| (0..cols).map(|_| rand::random::<f32>() * 0.1 - 0.05).collect())
            .collect()
    }

    /// Quantizes `address` into a vocabulary bucket by linearly scaling it
    /// over the observed `[min_address, max_address]` range. Returns bucket 0
    /// while fewer than two distinct addresses have been seen, or when the
    /// vocabulary is degenerate (`vocab_size < 2`).
    fn address_to_vocab(&self, address: u64) -> usize {
        if self.max_address <= self.min_address || self.vocab_size < 2 {
            return 0;
        }
        let span = (self.max_address - self.min_address) as f64;
        // saturating_sub + clamp keep addresses outside the observed range
        // in a valid bucket instead of underflowing / overrunning.
        let norm = address.saturating_sub(self.min_address) as f64 / span;
        (norm.clamp(0.0, 1.0) * (self.vocab_size as f64 - 1.0)) as usize
    }

    /// Inverse of [`Self::address_to_vocab`]: maps a bucket index back to a
    /// representative address inside the observed range.
    fn vocab_to_address(&self, vocab_idx: usize) -> u64 {
        if self.max_address <= self.min_address || self.vocab_size < 2 {
            // Degenerate range or vocabulary: everything maps to the minimum
            // address. (The vocab_size guard also avoids 0/0 = NaN below.)
            return self.min_address;
        }
        let norm = vocab_idx as f64 / (self.vocab_size as f64 - 1.0);
        self.min_address + ((self.max_address - self.min_address) as f64 * norm) as u64
    }

    fn relu(x: f32) -> f32 {
        x.max(0.0)
    }

    /// Numerically stable softmax: logits are shifted by their maximum
    /// before exponentiation to avoid overflow.
    fn softmax(logits: &[f32]) -> Vec<f32> {
        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
        let sum: f32 = exps.iter().sum();
        exps.iter().map(|&x| x / sum).collect()
    }

    /// Scalar feature for one history slot: `(idx + 1) / vocab_size` in
    /// (0, 1] for a valid bucket, 0.0 for an out-of-range one, so every
    /// valid token contributes a nonzero, value-dependent signal.
    fn input_feature(&self, idx: usize) -> f32 {
        if idx < self.vocab_size {
            (idx as f32 + 1.0) / self.vocab_size as f32
        } else {
            0.0
        }
    }

    /// Forward pass that also returns the hidden activations, which
    /// `backward_update` needs for backpropagation. The caller must pass a
    /// sequence of exactly `sequence_length` elements.
    fn forward_with_hidden(&self, input_seq: &[usize]) -> (Vec<f32>, Vec<f32>) {
        let mut hidden = vec![0.0; self.hidden_size];
        for h in 0..self.hidden_size {
            let mut sum = self.input_bias[h];
            for (i, &idx) in input_seq.iter().enumerate() {
                // BUG FIX: scale each weight by the token's value. The
                // original added the bare weight whenever the token was in
                // range, making the output independent of which addresses
                // were actually accessed.
                sum += self.input_weights[i][h] * self.input_feature(idx);
            }
            hidden[h] = Self::relu(sum);
        }
        let mut output = vec![0.0; self.vocab_size];
        for o in 0..self.vocab_size {
            let mut sum = self.hidden_bias[o];
            for h in 0..self.hidden_size {
                sum += hidden[h] * self.hidden_weights[h][o];
            }
            output[o] = sum;
        }
        (hidden, Self::softmax(&output))
    }

    /// Predicts a probability distribution over vocabulary buckets from a
    /// sequence of `sequence_length` bucket indices. A wrong-length input
    /// yields an all-zero vector.
    fn forward(&self, input_seq: &[usize]) -> Vec<f32> {
        if input_seq.len() != self.sequence_length {
            return vec![0.0; self.vocab_size];
        }
        self.forward_with_hidden(input_seq).1
    }

    /// One SGD step on the softmax cross-entropy loss of predicting `target`
    /// from `input_seq`, backpropagating through both layers.
    fn backward_update(&mut self, input_seq: &[usize], target: usize) {
        if input_seq.len() != self.sequence_length || target >= self.vocab_size {
            return;
        }
        let (hidden, mut delta_out) = self.forward_with_hidden(input_seq);
        // d(loss)/d(logit_o) for softmax + cross-entropy: p_o - onehot(target).
        delta_out[target] -= 1.0;

        // Backprop into the hidden layer BEFORE mutating hidden_weights.
        let mut delta_hidden = vec![0.0; self.hidden_size];
        for h in 0..self.hidden_size {
            // ReLU derivative: gradient flows only through active units.
            if hidden[h] > 0.0 {
                let mut sum = 0.0;
                for o in 0..self.vocab_size {
                    sum += delta_out[o] * self.hidden_weights[h][o];
                }
                delta_hidden[h] = sum;
            }
        }

        // Output layer. BUG FIX: the gradient w.r.t. hidden_weights[h][o] is
        // delta_out[o] * hidden[h]; the original omitted the activation.
        for h in 0..self.hidden_size {
            for o in 0..self.vocab_size {
                self.hidden_weights[h][o] -= self.learning_rate * delta_out[o] * hidden[h];
            }
        }
        for o in 0..self.vocab_size {
            self.hidden_bias[o] -= self.learning_rate * delta_out[o];
        }

        // Input layer. BUG FIX: the original scaled by the sum of delta_out,
        // which is exactly zero after the one-hot subtraction (softmax sums
        // to 1), so the input weights never trained; input_bias was never
        // updated at all.
        for (i, &idx) in input_seq.iter().enumerate() {
            let x = self.input_feature(idx);
            if x != 0.0 {
                for h in 0..self.hidden_size {
                    self.input_weights[i][h] -= self.learning_rate * delta_hidden[h] * x;
                }
            }
        }
        for h in 0..self.hidden_size {
            self.input_bias[h] -= self.learning_rate * delta_hidden[h];
        }
        self.training_steps += 1;
    }

    /// Grows the observed address range to include `addr`.
    fn update_range(&mut self, addr: u64) {
        self.min_address = self.min_address.min(addr);
        self.max_address = self.max_address.max(addr);
    }

    /// Turns a probability distribution into at most `max_predictions`
    /// `(address, probability)` pairs whose probability clears the
    /// confidence threshold, ordered by descending probability.
    fn get_top_predictions(&self, probs: &[f32]) -> Vec<(u64, f32)> {
        let mut indexed: Vec<(usize, f32)> = probs
            .iter()
            .enumerate()
            .filter(|&(_, &prob)| prob >= self.confidence_threshold)
            .map(|(idx, &prob)| (idx, prob))
            .collect();
        // total_cmp is NaN-safe; the original partial_cmp().unwrap() panics on NaN.
        indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
        indexed.truncate(self.max_predictions);
        indexed
            .into_iter()
            .map(|(idx, prob)| (self.vocab_to_address(idx), prob))
            .collect()
    }
}
impl<K> PrefetchStrategy<K> for NeuralPrefetch<K>
where
    K: Copy + KeyConvert,
{
    /// Predicts likely next keys for the access at `key`. Until enough
    /// history has accumulated, falls back to a sequential heuristic
    /// (the next three addresses).
    fn predict_next(&mut self, key: &K) -> Vec<K> {
        let addr = key.to_u64();
        if self.access_history.len() < self.sequence_length {
            // Cold start: sequential prefetch. wrapping_add avoids a
            // debug-mode overflow panic for keys near u64::MAX.
            return (1..=3).map(|i| K::from_u64(addr.wrapping_add(i))).collect();
        }
        let input_seq: Vec<usize> = self
            .access_history
            .iter()
            .map(|&a| self.address_to_vocab(a))
            .collect();
        let probs = self.forward(&input_seq);
        self.get_top_predictions(&probs)
            .into_iter()
            .map(|(addr, _)| K::from_u64(addr))
            .collect()
    }

    /// Records an access: widens the observed address range, trains the
    /// model on (history window -> new access), then slides the window.
    fn update_access_pattern(&mut self, key: &K) {
        let addr = key.to_u64();
        self.update_range(addr);
        if self.access_history.len() >= self.sequence_length {
            // BUG FIX: the original took only sequence_length - 1 elements,
            // so backward_update always rejected the sequence as too short
            // and the model never trained. Feed the full window.
            let input_seq: Vec<usize> = self
                .access_history
                .iter()
                .take(self.sequence_length)
                .map(|&a| self.address_to_vocab(a))
                .collect();
            let target = self.address_to_vocab(addr);
            self.backward_update(&input_seq, target);
        }
        self.access_history.push_back(addr);
        if self.access_history.len() > self.sequence_length {
            self.access_history.pop_front();
        }
    }

    /// Forgets all learned state: clears the history, resets the observed
    /// address range to its "nothing seen" sentinel, re-randomizes the
    /// weights, and zeroes the biases.
    fn reset(&mut self) {
        self.access_history.clear();
        self.min_address = u64::MAX;
        self.max_address = 0;
        self.training_steps = 0;
        self.input_weights = Self::init_weights(self.sequence_length, self.hidden_size);
        self.hidden_weights = Self::init_weights(self.hidden_size, self.vocab_size);
        self.input_bias.fill(0.0);
        self.hidden_bias.fill(0.0);
    }
}
impl<K> Default for NeuralPrefetch<K>
where
    K: Copy + KeyConvert,
{
    /// Identical to [`NeuralPrefetch::new`]: the default hyperparameter set.
    fn default() -> Self {
        Self::new()
    }
}