use crate::PrefetchStrategy;
use crate::prefetch::neural::KeyConvert;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;
/// Transformer-style prefetch predictor over a sliding window of accesses.
///
/// Recently accessed keys (converted to `u64` addresses) are embedded,
/// combined with sinusoidal positional encodings, and run through a
/// multi-head self-attention pass; candidate addresses are then ranked by
/// dot-product similarity against the last attended position.
#[derive(Debug, Clone)]
pub struct TransformerPrefetch<K> {
// Number of parallel attention heads used in `self_attention`.
attention_heads: usize,
// Per-head slice width: `embedding_dim / attention_heads`.
head_dim: usize,
// Maximum number of recent accesses kept in `access_sequence`.
sequence_length: usize,
// Lazily built, deterministic embedding per observed address
// (capped at `max_vocab_size` entries).
address_embeddings: HashMap<u64, Vec<f32>>,
// Precomputed sinusoidal positional encodings, one per window slot.
positional_embeddings: Vec<Vec<f32>>,
// Q/K/V/output projection weights are initialized but not yet applied
// in `self_attention` (which attends over raw embeddings) — hence the
// dead_code allowance; presumably reserved for a fuller implementation.
#[allow(dead_code)]
query_weights: Vec<Vec<f32>>,
#[allow(dead_code)]
key_weights: Vec<Vec<f32>>,
#[allow(dead_code)]
value_weights: Vec<Vec<f32>>,
#[allow(dead_code)]
output_weights: Vec<Vec<f32>>,
// Sliding window of the most recent accesses (bounded FIFO).
access_sequence: VecDeque<u64>,
// Cap on distinct addresses that get a learned embedding.
max_vocab_size: usize,
// Dimensionality of every embedding vector.
embedding_dim: usize,
// Stored for future training; not used by the current inference path.
#[allow(dead_code)]
learning_rate: f32,
// Maximum number of addresses returned by a prediction.
max_predictions: usize,
// Ties the key type `K` to the struct without storing one.
_marker: PhantomData<K>,
}
impl<K> TransformerPrefetch<K> {
    /// Creates a prefetcher with the default configuration:
    /// 8 heads, 64-dim embeddings, a 16-entry window, lr 0.001, 3 predictions.
    pub fn new() -> Self {
        Self::with_config(8, 64, 16, 0.001, 3)
    }

    /// Creates a prefetcher with an explicit configuration.
    ///
    /// `attention_heads` is clamped to at least 1, so a zero value can no
    /// longer cause a divide-by-zero panic when the per-head dimension is
    /// computed. If `attention_heads` does not evenly divide
    /// `embedding_dim`, the trailing remainder dimensions are simply not
    /// attended over (same as before).
    pub fn with_config(
        attention_heads: usize,
        embedding_dim: usize,
        sequence_length: usize,
        learning_rate: f32,
        max_predictions: usize,
    ) -> Self {
        // Guard: 0 heads is a caller error but must not panic here.
        let attention_heads = attention_heads.max(1);
        let head_dim = embedding_dim / attention_heads;
        Self {
            attention_heads,
            head_dim,
            sequence_length,
            address_embeddings: HashMap::new(),
            // Precompute one sinusoidal encoding per window slot.
            positional_embeddings: (0..sequence_length)
                .map(|pos| Self::create_positional_encoding(pos, embedding_dim))
                .collect(),
            query_weights: Self::init_weights(embedding_dim, embedding_dim),
            key_weights: Self::init_weights(embedding_dim, embedding_dim),
            value_weights: Self::init_weights(embedding_dim, embedding_dim),
            output_weights: Self::init_weights(embedding_dim, embedding_dim),
            access_sequence: VecDeque::with_capacity(sequence_length),
            max_vocab_size: 10000,
            embedding_dim,
            learning_rate,
            max_predictions,
            _marker: PhantomData,
        }
    }

    /// Builds a deterministic (pseudo-random-looking) weight matrix in
    /// the range [-0.2, 0.2). No RNG is used so construction is reproducible.
    fn init_weights(rows: usize, cols: usize) -> Vec<Vec<f32>> {
        (0..rows)
            .map(|i| {
                (0..cols)
                    .map(|j| ((i * cols + j) as f32 * 0.02) % 0.4 - 0.2)
                    .collect()
            })
            .collect()
    }

    /// Standard sinusoidal positional encoding: sin on even dimensions,
    /// cos on odd dimensions, with the frequency shared per sin/cos pair.
    fn create_positional_encoding(position: usize, embedding_dim: usize) -> Vec<f32> {
        (0..embedding_dim)
            .map(|i| {
                if i % 2 == 0 {
                    (position as f32 / 10000_f32.powf(i as f32 / embedding_dim as f32)).sin()
                } else {
                    (position as f32 / 10000_f32.powf((i - 1) as f32 / embedding_dim as f32)).cos()
                }
            })
            .collect()
    }

    /// Returns the embedding for `address`, creating one lazily.
    ///
    /// Embeddings are a deterministic function of the address, so the same
    /// address always maps to the same vector. Once the vocabulary cap is
    /// reached, a zero vector is returned instead of evicting entries.
    fn get_or_create_embedding(&mut self, address: u64) -> Vec<f32> {
        // Fast path: single lookup for already-known addresses.
        if let Some(embedding) = self.address_embeddings.get(&address) {
            return embedding.clone();
        }
        if self.address_embeddings.len() >= self.max_vocab_size {
            // Vocabulary full: neutral embedding rather than unbounded growth.
            return vec![0.0; self.embedding_dim];
        }
        let dim = self.embedding_dim;
        let embedding: Vec<f32> = (0..dim)
            .map(|i| {
                // wrapping_* avoids the debug-build overflow panic the old
                // `address as usize * dim + i` hit for large (pointer-sized)
                // addresses; in release builds the result is identical.
                let seed = (address as usize).wrapping_mul(dim).wrapping_add(i);
                (seed as f32 * 0.01) % 2.0 - 1.0
            })
            .collect();
        self.address_embeddings.insert(address, embedding.clone());
        embedding
    }

    /// Multi-head scaled dot-product self-attention over `embeddings`.
    ///
    /// Each head attends over its own `head_dim`-wide slice of the
    /// embedding; Q, K and V are all the raw slice (the projection weights
    /// are not applied yet). Scores are softmax-normalized per row with the
    /// usual max-subtraction trick for numerical stability.
    fn self_attention(&self, embeddings: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let seq_len = embeddings.len();
        let mut output = vec![vec![0.0; self.embedding_dim]; seq_len];
        for head in 0..self.attention_heads {
            let head_start = head * self.head_dim;
            let mut queries = vec![vec![0.0; self.head_dim]; seq_len];
            let mut keys = vec![vec![0.0; self.head_dim]; seq_len];
            let mut values = vec![vec![0.0; self.head_dim]; seq_len];
            for i in 0..seq_len {
                for j in 0..self.head_dim {
                    queries[i][j] = embeddings[i][head_start + j];
                    keys[i][j] = embeddings[i][head_start + j];
                    values[i][j] = embeddings[i][head_start + j];
                }
            }
            let mut attention_weights = vec![vec![0.0; seq_len]; seq_len];
            for i in 0..seq_len {
                for j in 0..seq_len {
                    // Scaled dot product: q·k / sqrt(head_dim).
                    let score = queries[i]
                        .iter()
                        .zip(keys[j].iter())
                        .map(|(&q, &k)| q * k)
                        .sum::<f32>()
                        / (self.head_dim as f32).sqrt();
                    attention_weights[i][j] = score;
                }
                // Numerically stable softmax over row i.
                let max_score = attention_weights[i]
                    .iter()
                    .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
                let exp_scores: Vec<f32> = attention_weights[i]
                    .iter()
                    .map(|&x| (x - max_score).exp())
                    .collect();
                let sum_exp: f32 = exp_scores.iter().sum();
                if sum_exp > 0.0 {
                    for j in 0..seq_len {
                        attention_weights[i][j] = exp_scores[j] / sum_exp;
                    }
                }
            }
            // Weighted sum of values back into this head's output slice.
            for i in 0..seq_len {
                for d in 0..self.head_dim {
                    let attended_value = (0..seq_len)
                        .map(|j| attention_weights[i][j] * values[j][d])
                        .sum();
                    output[i][head_start + d] = attended_value;
                }
            }
        }
        output
    }

    /// Predicts up to `max_predictions` likely-next addresses.
    ///
    /// Embeds the recent access window (plus scaled positional encodings),
    /// runs self-attention, and ranks all known addresses by dot-product
    /// similarity to the last attended position, excluding addresses already
    /// present in the window. Returns an empty list when fewer than two
    /// accesses have been recorded.
    fn predict_next_addresses(&mut self) -> Vec<u64> {
        if self.access_sequence.len() < 2 {
            return Vec::new();
        }
        // Copy the u64 addresses out (cheaper than cloning the VecDeque)
        // so `self` can be mutably borrowed while building embeddings.
        let addresses: Vec<u64> = self.access_sequence.iter().copied().collect();
        let embeddings: Vec<Vec<f32>> = addresses
            .iter()
            .enumerate()
            .map(|(pos, &addr)| {
                let mut emb = self.get_or_create_embedding(addr);
                // Blend in the (down-scaled) positional encoding for this slot.
                if let Some(pos_enc) = self.positional_embeddings.get(pos) {
                    for (e, &p) in emb.iter_mut().zip(pos_enc.iter()) {
                        *e += p * 0.1;
                    }
                }
                emb
            })
            .collect();
        let attended = self.self_attention(&embeddings);
        let last_attended = match attended.last() {
            Some(v) => v,
            None => return Vec::new(),
        };
        // Skip addresses already in the recent window before scoring.
        let recent_set: std::collections::HashSet<u64> = addresses.iter().copied().collect();
        let mut scored_addresses: Vec<(u64, f32)> = self
            .address_embeddings
            .iter()
            .filter(|(addr, _)| !recent_set.contains(addr))
            .map(|(&addr, embedding)| {
                let similarity: f32 = last_attended
                    .iter()
                    .zip(embedding.iter())
                    .map(|(&a, &e)| a * e)
                    .sum();
                (addr, similarity)
            })
            .collect();
        // total_cmp is total (NaN-safe); the previous
        // `partial_cmp().unwrap()` would panic on a NaN similarity.
        scored_addresses.sort_by(|a, b| b.1.total_cmp(&a.1));
        scored_addresses.truncate(self.max_predictions);
        scored_addresses.into_iter().map(|(addr, _)| addr).collect()
    }
}
impl<K> PrefetchStrategy<K> for TransformerPrefetch<K>
where
    K: Copy + KeyConvert,
{
    /// Maps the model's predicted raw addresses back into keys.
    /// The trailing `take` defensively re-applies the prediction cap.
    fn predict_next(&mut self, _accessed_key: &K) -> Vec<K> {
        let predicted = self.predict_next_addresses();
        predicted
            .iter()
            .map(|&addr| K::from_u64(addr))
            .take(self.max_predictions)
            .collect()
    }

    /// Records an access, keeping only the newest `sequence_length` entries.
    fn update_access_pattern(&mut self, key: &K) {
        self.access_sequence.push_back(key.to_u64());
        while self.access_sequence.len() > self.sequence_length {
            self.access_sequence.pop_front();
        }
    }

    /// Drops all recorded history and every learned embedding.
    fn reset(&mut self) {
        self.address_embeddings.clear();
        self.access_sequence.clear();
    }
}
impl<K> Default for TransformerPrefetch<K>
where
K: Copy + KeyConvert,
{
fn default() -> Self {
Self::new()
}
}