fulgurance 0.4.1

A blazing-fast, adaptive prefetching and caching library for Rust
//! Transformer-based Prefetch Strategy
//!
//! Uses a self-attention mechanism to capture complex memory
//! access dependencies, focusing on long-range irregular patterns.
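//!
//! # Example
//!
//! A minimal usage sketch, not taken from the crate's own docs: it assumes
//! `KeyConvert` is implemented for `u64` and guesses the public module path,
//! so the block is marked `ignore` rather than compiled as a doctest.
//!
//! ```ignore
//! use fulgurance::PrefetchStrategy;
//! use fulgurance::prefetch::transformer::TransformerPrefetch;
//!
//! let mut strategy: TransformerPrefetch<u64> = TransformerPrefetch::new();
//!
//! // Record a short access history, then ask for prefetch candidates.
//! for addr in [0x1000u64, 0x1040, 0x1080] {
//!     strategy.update_access_pattern(&addr);
//! }
//! let candidates: Vec<u64> = strategy.predict_next(&0x10c0);
//! ```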

use crate::PrefetchStrategy;
use crate::prefetch::neural::KeyConvert;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;

/// Prefetch strategy that runs a lightweight multi-head self-attention pass
/// over the recent access sequence to predict which keys to fetch next.
#[derive(Debug, Clone)]
pub struct TransformerPrefetch<K> {
    attention_heads: usize,
    head_dim: usize,
    sequence_length: usize,
    address_embeddings: HashMap<u64, Vec<f32>>,
    positional_embeddings: Vec<Vec<f32>>,
    #[allow(dead_code)]
    query_weights: Vec<Vec<f32>>,
    #[allow(dead_code)]
    key_weights: Vec<Vec<f32>>,
    #[allow(dead_code)]
    value_weights: Vec<Vec<f32>>,
    #[allow(dead_code)]
    output_weights: Vec<Vec<f32>>,
    access_sequence: VecDeque<u64>,
    max_vocab_size: usize,
    embedding_dim: usize,
    #[allow(dead_code)]
    learning_rate: f32,
    max_predictions: usize,
    _marker: PhantomData<K>,
}

impl<K> TransformerPrefetch<K> {
    /// Creates a strategy with the default configuration: 8 attention heads,
    /// 64-dimensional embeddings, a 16-access window, a 0.001 learning rate,
    /// and at most 3 predictions per access.
    pub fn new() -> Self {
        Self::with_config(8, 64, 16, 0.001, 3)
    }

    /// Creates a strategy with an explicit configuration.
    ///
    /// `embedding_dim` should be a multiple of `attention_heads`; any
    /// remainder dimensions are simply never attended over.
    pub fn with_config(
        attention_heads: usize,
        embedding_dim: usize,
        sequence_length: usize,
        learning_rate: f32,
        max_predictions: usize,
    ) -> Self {
        let head_dim = embedding_dim / attention_heads;
        Self {
            attention_heads,
            head_dim,
            sequence_length,
            address_embeddings: HashMap::new(),
            positional_embeddings: (0..sequence_length)
                .map(|pos| Self::create_positional_encoding(pos, embedding_dim))
                .collect(),
            query_weights: Self::init_weights(embedding_dim, embedding_dim),
            key_weights: Self::init_weights(embedding_dim, embedding_dim),
            value_weights: Self::init_weights(embedding_dim, embedding_dim),
            output_weights: Self::init_weights(embedding_dim, embedding_dim),
            access_sequence: VecDeque::with_capacity(sequence_length),
            max_vocab_size: 10000,
            embedding_dim,
            learning_rate,
            max_predictions,
            _marker: PhantomData,
        }
    }

    /// Deterministic weight initialization in [-0.2, 0.2). A fixed pattern
    /// keeps runs reproducible without pulling in an RNG dependency.
    fn init_weights(rows: usize, cols: usize) -> Vec<Vec<f32>> {
        (0..rows)
            .map(|i| (0..cols)
                .map(|j| ((i * cols + j) as f32 * 0.02) % 0.4 - 0.2)
                .collect())
            .collect()
    }

    /// Standard sinusoidal positional encoding: dimension `2i` uses
    /// `sin(pos / 10000^(2i/d))` and dimension `2i + 1` uses the matching
    /// cosine, so each pair of dimensions encodes position at one frequency.
    fn create_positional_encoding(position: usize, embedding_dim: usize) -> Vec<f32> {
        (0..embedding_dim)
            .map(|i| {
                if i % 2 == 0 {
                    (position as f32 / 10000_f32.powf(i as f32 / embedding_dim as f32)).sin()
                } else {
                    (position as f32 / 10000_f32.powf((i-1) as f32 / embedding_dim as f32)).cos()
                }
            })
            .collect()
    }

    /// Returns the embedding for `address`, creating a deterministic one the
    /// first time it is seen. Once the vocabulary is full, unknown addresses
    /// fall back to a zero vector so memory stays bounded.
    fn get_or_create_embedding(&mut self, address: u64) -> Vec<f32> {
        if let Some(embedding) = self.address_embeddings.get(&address) {
            embedding.clone()
        } else if self.address_embeddings.len() < self.max_vocab_size {
            // Wrapping arithmetic avoids overflow panics in debug builds when
            // addresses are large; only the low bits matter for this hash.
            let embedding: Vec<f32> = (0..self.embedding_dim)
                .map(|i| {
                    let seed = (address as usize)
                        .wrapping_mul(self.embedding_dim)
                        .wrapping_add(i);
                    (seed as f32 * 0.01) % 2.0 - 1.0
                })
                .collect();
            self.address_embeddings.insert(address, embedding.clone());
            embedding
        } else {
            vec![0.0; self.embedding_dim]
        }
    }

    /// Multi-head scaled dot-product self-attention over the embedded
    /// sequence. This simplified forward pass uses identity projections:
    /// the stored query/key/value/output weight matrices are not yet applied.
    fn self_attention(&self, embeddings: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let seq_len = embeddings.len();
        let mut output = vec![vec![0.0; self.embedding_dim]; seq_len];
        
        for head in 0..self.attention_heads {
            let head_start = head * self.head_dim;
            let mut queries = vec![vec![0.0; self.head_dim]; seq_len];
            let mut keys = vec![vec![0.0; self.head_dim]; seq_len];
            let mut values = vec![vec![0.0; self.head_dim]; seq_len];
            
            // Slice out this head's dimensions; under the identity projection,
            // queries, keys, and values are the same embedding slice.
            for i in 0..seq_len {
                for j in 0..self.head_dim {
                    queries[i][j] = embeddings[i][head_start + j];
                    keys[i][j] = embeddings[i][head_start + j];
                    values[i][j] = embeddings[i][head_start + j];
                }
            }
            
            // Calculate attention weights
            let mut attention_weights = vec![vec![0.0; seq_len]; seq_len];
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let score = queries[i].iter()
                        .zip(keys[j].iter())
                        .map(|(&q, &k)| q * k)
                        .sum::<f32>() / (self.head_dim as f32).sqrt();
                    attention_weights[i][j] = score;
                }
                
                // Softmax over the row, with max-subtraction for stability
                let max_score = attention_weights[i].iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
                let exp_scores: Vec<f32> = attention_weights[i].iter()
                    .map(|&x| (x - max_score).exp())
                    .collect();
                let sum_exp: f32 = exp_scores.iter().sum();
                if sum_exp > 0.0 {
                    for j in 0..seq_len {
                        attention_weights[i][j] = exp_scores[j] / sum_exp;
                    }
                }
            }
            
            // Apply attention to values
            for i in 0..seq_len {
                for d in 0..self.head_dim {
                    let attended_value = (0..seq_len)
                        .map(|j| attention_weights[i][j] * values[j][d])
                        .sum();
                    output[i][head_start + d] = attended_value;
                }
            }
        }
        
        output
    }

    /// Embeds the recent access window, runs self-attention over it, and
    /// scores every known address by dot-product similarity with the final
    /// attended state.
    fn predict_next_addresses(&mut self) -> Vec<u64> {
        if self.access_sequence.len() < 2 {
            return Vec::new();
        }
        
        // Clone the window first: `get_or_create_embedding` needs `&mut self`.
        let access_seq_clone = self.access_sequence.clone();
        let embeddings: Vec<Vec<f32>> = access_seq_clone.iter()
            .enumerate()
            .map(|(pos, &addr)| {
                let mut emb = self.get_or_create_embedding(addr);
                if pos < self.positional_embeddings.len() {
                    for (i, &pos_enc) in self.positional_embeddings[pos].iter().enumerate() {
                        if i < emb.len() {
                            emb[i] += pos_enc * 0.1;
                        }
                    }
                }
                emb
            })
            .collect();
        
        let attended = self.self_attention(&embeddings);
        
        if let Some(last_attended) = attended.last() {
            let mut scored_addresses: Vec<(u64, f32)> = self.address_embeddings.iter()
                .map(|(&addr, embedding)| {
                    let similarity = last_attended.iter()
                        .zip(embedding.iter())
                        .map(|(&a, &e)| a * e)
                        .sum::<f32>();
                    (addr, similarity)
                })
                .collect();
            
            // Remove recently accessed addresses
            let recent_set: std::collections::HashSet<u64> = self.access_sequence.iter().cloned().collect();
            scored_addresses.retain(|(addr, _)| !recent_set.contains(addr));
            
            // `total_cmp` avoids the panic that `partial_cmp(..).unwrap()`
            // would hit if a similarity score were NaN.
            scored_addresses.sort_by(|a, b| b.1.total_cmp(&a.1));
            scored_addresses.truncate(self.max_predictions);
            scored_addresses.into_iter().map(|(addr, _)| addr).collect()
        } else {
            Vec::new()
        }
    }
}

impl<K> PrefetchStrategy<K> for TransformerPrefetch<K>
where
    K: Copy + KeyConvert,
{
    fn predict_next(&mut self, _accessed_key: &K) -> Vec<K> {
        // `predict_next_addresses` already truncates to `max_predictions`.
        self.predict_next_addresses()
            .into_iter()
            .map(K::from_u64)
            .collect()
    }

    fn update_access_pattern(&mut self, key: &K) {
        // Maintain a fixed-length sliding window of recent accesses.
        let addr = key.to_u64();
        self.access_sequence.push_back(addr);
        if self.access_sequence.len() > self.sequence_length {
            self.access_sequence.pop_front();
        }
    }

    fn reset(&mut self) {
        self.access_sequence.clear();
        self.address_embeddings.clear();
    }
}
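
// Sketch: adapting a custom key type to this strategy. Judging from the
// usage above, `KeyConvert` exposes a `to_u64`/`from_u64` pair; this impl is
// hypothetical, so check the trait definition in `crate::prefetch::neural`
// before uncommenting.
//
// #[derive(Copy, Clone)]
// struct PageId(u64);
//
// impl KeyConvert for PageId {
//     fn to_u64(&self) -> u64 { self.0 }
//     fn from_u64(addr: u64) -> Self { PageId(addr) }
// }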

impl<K> Default for TransformerPrefetch<K>
where
    K: Copy + KeyConvert,
{
    fn default() -> Self {
        Self::new()
    }
}
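
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke-test sketch (not from the original crate): feed a window of
    // addresses straight into the internal buffer, register one candidate
    // outside that window, and check that it is the only prediction.
    #[test]
    fn predicts_only_non_recent_addresses() {
        let mut prefetch = TransformerPrefetch::<u64>::with_config(4, 32, 8, 0.001, 2);
        for addr in [0x1000u64, 0x1040, 0x1080, 0x10c0] {
            prefetch.access_sequence.push_back(addr);
        }
        // Pre-register a candidate that is *not* in the recent window.
        prefetch.get_or_create_embedding(0x2000);

        let predicted = prefetch.predict_next_addresses();
        assert_eq!(predicted, vec![0x2000]);
    }

    // With identical embeddings, attention weights are uniform, so the
    // output must reproduce the (finite) input values.
    #[test]
    fn attention_output_is_finite() {
        let prefetch = TransformerPrefetch::<u64>::with_config(2, 16, 4, 0.001, 2);
        let embeddings = vec![vec![0.5; 16]; 4];
        let out = prefetch.self_attention(&embeddings);
        assert_eq!(out.len(), 4);
        assert!(out.iter().flatten().all(|v| v.is_finite()));
    }
}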