// fulgurance 0.4.0
//
// A blazing-fast, adaptive prefetching and caching library for Rust.
// Documentation
use crate::PrefetchStrategy;
use std::collections::VecDeque;
use std::marker::PhantomData;

/// Conversion between a cache key and the `u64` "address" space used by the neural
/// prefetcher to normalize and bucket keys.
pub trait KeyConvert {
    /// Projects this key onto a `u64` address.
    fn to_u64(&self) -> u64;

    /// Rebuilds a key from a `u64` address produced by the prefetcher.
    fn from_u64(addr: u64) -> Self;
}

impl KeyConvert for i32 {
    /// Sign-extends to 64 bits, so negative keys land near `u64::MAX`.
    fn to_u64(&self) -> u64 {
        i64::from(*self) as u64
    }

    /// Keeps only the low 32 bits and reinterprets them as a signed value.
    fn from_u64(addr: u64) -> Self {
        addr as u32 as i32
    }
}

impl KeyConvert for i64 {
    /// Reinterprets the two's-complement bits as an unsigned address.
    fn to_u64(&self) -> u64 {
        u64::from_ne_bytes(self.to_ne_bytes())
    }

    /// Reinterprets the address bits as a signed key (inverse of `to_u64`).
    fn from_u64(addr: u64) -> Self {
        i64::from_ne_bytes(addr.to_ne_bytes())
    }
}

impl KeyConvert for usize {
    /// Widens the index to a 64-bit address (lossless on 32- and 64-bit targets).
    fn to_u64(&self) -> u64 {
        let index: usize = *self;
        index as u64
    }

    /// Narrows the address back to `usize`; on 32-bit targets the high bits are dropped.
    fn from_u64(addr: u64) -> Self {
        addr as Self
    }
}

/// Online neural prefetcher over keys of type `K`.
///
/// Keys are projected to `u64` addresses, bucketed into `vocab_size` classes over the
/// observed address range, and a one-hidden-layer network predicts the next bucket
/// from the last `sequence_length` accesses.
#[derive(Debug, Clone)]
pub struct NeuralPrefetch<K> {
    /// Input-to-hidden weights, one row per history position (`sequence_length` rows of
    /// `hidden_size` columns when built by the constructor).
    input_weights: Vec<Vec<f32>>,
    /// Hidden-to-output weights (`hidden_size` rows of `vocab_size` columns).
    hidden_weights: Vec<Vec<f32>>,
    /// Bias added to every hidden unit before the ReLU.
    input_bias: Vec<f32>,
    /// Bias added to every output logit before the softmax.
    hidden_bias: Vec<f32>,
    /// Most recent accessed addresses, oldest at the front.
    access_history: VecDeque<u64>,
    /// Number of past accesses the network conditions on.
    sequence_length: usize,
    /// Number of hidden units.
    hidden_size: usize,
    /// Number of address buckets the network predicts over.
    vocab_size: usize,
    /// Step size of the gradient update.
    learning_rate: f32,
    /// Smallest address seen so far (`u64::MAX` until the first access).
    min_address: u64,
    /// Largest address seen so far (`0` until the first access).
    max_address: u64,
    /// Minimum probability for a bucket to be reported as a prediction.
    confidence_threshold: f32,
    /// Upper bound on the number of predictions returned.
    max_predictions: usize,
    /// Count of gradient updates applied.
    training_steps: usize,
    /// Ties the prefetcher to its key type without storing any `K`.
    _marker: PhantomData<K>,
}

impl<K> NeuralPrefetch<K> {
    /// Creates a new instance with the default configuration.
    ///
    /// Equivalent to `with_config(8, 32, 1000, 0.01, 0.5, 3)`: an 8-access window,
    /// 32 hidden units, 1000 address buckets, learning rate 0.01, confidence
    /// threshold 0.5 and at most 3 predictions per call.
    pub fn new() -> Self {
        Self::with_config(8, 32, 1000, 0.01, 0.5, 3)
    }

    /// Creates a new instance with the given configuration.
    ///
    /// * `sequence_length` – number of past accesses fed to the network; clamped to at least 1.
    /// * `hidden_size` – number of units in the single ReLU hidden layer.
    /// * `vocab_size` – number of buckets the observed address range is quantized into;
    ///   clamped to at least 1.
    /// * `learning_rate` – step size of the online gradient-descent update.
    /// * `confidence_threshold` – minimum predicted probability for a bucket to be returned.
    /// * `max_predictions` – maximum number of predictions returned per call.
    ///
    /// Weights are initialised from `rand::random::<f32>() * 0.1 - 0.05`; biases start at zero.
    pub fn with_config(
        sequence_length: usize,
        hidden_size: usize,
        vocab_size: usize,
        learning_rate: f32,
        confidence_threshold: f32,
        max_predictions: usize,
    ) -> Self {
        // Guarantee every window has at least one input and every address at least one
        // bucket; a zero here would leave nothing to train on or predict into.
        let sequence_length = sequence_length.max(1);
        let vocab_size = vocab_size.max(1);
        Self {
            input_weights: Self::init_weights(sequence_length, hidden_size),
            hidden_weights: Self::init_weights(hidden_size, vocab_size),
            input_bias: vec![0.0; hidden_size],
            hidden_bias: vec![0.0; vocab_size],
            // +1 because the history briefly holds one extra entry before trimming.
            access_history: VecDeque::with_capacity(sequence_length + 1),
            sequence_length,
            hidden_size,
            vocab_size,
            learning_rate,
            min_address: u64::MAX,
            max_address: 0,
            confidence_threshold,
            max_predictions,
            training_steps: 0,
            _marker: PhantomData,
        }
    }

    /// Builds a `rows x cols` matrix of small random weights centred on zero.
    fn init_weights(rows: usize, cols: usize) -> Vec<Vec<f32>> {
        (0..rows)
            .map(|_| (0..cols).map(|_| rand::random::<f32>() * 0.1 - 0.05).collect())
            .collect()
    }

    /// Maps an address to a bucket in `0..vocab_size` by linear interpolation over the
    /// observed `[min_address, max_address]` range.
    ///
    /// Addresses outside the range are clamped to its ends; a degenerate range (or a
    /// single bucket) maps everything to bucket 0.
    fn address_to_vocab(&self, address: u64) -> usize {
        if self.max_address <= self.min_address || self.vocab_size <= 1 {
            return 0;
        }
        // Clamping avoids `address - min_address` underflowing for stale addresses.
        let offset = address.clamp(self.min_address, self.max_address) - self.min_address;
        let norm = offset as f64 / (self.max_address - self.min_address) as f64;
        let last = self.vocab_size - 1;
        ((norm * last as f64) as usize).min(last)
    }

    /// Maps a bucket index back to a representative address inside the observed range.
    ///
    /// Returns `min_address` when the range is degenerate or there is a single bucket.
    fn vocab_to_address(&self, vocab_idx: usize) -> u64 {
        if self.max_address <= self.min_address || self.vocab_size <= 1 {
            return self.min_address;
        }
        let last = self.vocab_size - 1;
        let norm = vocab_idx.min(last) as f64 / last as f64;
        let range = self.max_address - self.min_address;
        // `range as f64` can round up past `range`; cap the step so the sum cannot
        // exceed `max_address` (and therefore cannot overflow).
        let step = ((range as f64 * norm) as u64).min(range);
        self.min_address + step
    }

    /// Rectified linear unit.
    fn relu(x: f32) -> f32 {
        x.max(0.0)
    }

    /// Numerically stable softmax (shifts by the maximum logit before exponentiating).
    fn softmax(logits: &[f32]) -> Vec<f32> {
        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
        let sum: f32 = exps.iter().sum();
        exps.iter().map(|&x| x / sum).collect()
    }

    /// Turns bucket indices into scalar network inputs in `(0, 1]`.
    ///
    /// Index `i` becomes `(i + 1) / vocab_size` so that every valid bucket (including 0)
    /// contributes a distinct non-zero signal; out-of-range indices contribute 0.
    fn encode_inputs(&self, input_seq: &[usize]) -> Vec<f32> {
        let scale = self.vocab_size as f32;
        input_seq
            .iter()
            .map(|&idx| {
                if idx < self.vocab_size {
                    (idx + 1) as f32 / scale
                } else {
                    0.0
                }
            })
            .collect()
    }

    /// Post-ReLU hidden-layer activations for the encoded inputs.
    fn hidden_activations(&self, inputs: &[f32]) -> Vec<f32> {
        let mut pre = self.input_bias.clone();
        for (&x, row) in inputs.iter().zip(&self.input_weights) {
            for (p, &w) in pre.iter_mut().zip(row) {
                *p += x * w;
            }
        }
        pre.into_iter().map(Self::relu).collect()
    }

    /// Output logits (pre-softmax) for the given hidden activations.
    fn output_logits(&self, hidden: &[f32]) -> Vec<f32> {
        let mut logits = self.hidden_bias.clone();
        for (&a, row) in hidden.iter().zip(&self.hidden_weights) {
            for (l, &w) in logits.iter_mut().zip(row) {
                *l += a * w;
            }
        }
        logits
    }

    /// Returns the predicted probability of each bucket given a window of bucket indices.
    ///
    /// A window whose length differs from `sequence_length` yields all zeros.
    fn forward(&self, input_seq: &[usize]) -> Vec<f32> {
        if input_seq.len() != self.sequence_length {
            return vec![0.0; self.vocab_size];
        }
        let inputs = self.encode_inputs(input_seq);
        let hidden = self.hidden_activations(&inputs);
        Self::softmax(&self.output_logits(&hidden))
    }

    /// Applies one step of stochastic gradient descent on the cross-entropy loss of
    /// predicting `target` from `input_seq`, updating both layers and their biases.
    ///
    /// Ignored when the window length differs from `sequence_length` or `target` is not
    /// a valid bucket.
    fn backward_update(&mut self, input_seq: &[usize], target: usize) {
        if input_seq.len() != self.sequence_length || target >= self.vocab_size {
            return;
        }
        let inputs = self.encode_inputs(input_seq);
        let hidden = self.hidden_activations(&inputs);
        let mut delta = Self::softmax(&self.output_logits(&hidden));
        // d(cross-entropy)/d(logits) = softmax - one_hot(target).
        delta[target] -= 1.0;

        // Backpropagate into the hidden layer *before* changing `hidden_weights`, so the
        // gradient is taken w.r.t. the weights that produced this prediction. The ReLU
        // derivative is 1 exactly where the activation is positive.
        let hidden_grad: Vec<f32> = hidden
            .iter()
            .zip(&self.hidden_weights)
            .map(|(&a, row)| {
                if a > 0.0 {
                    row.iter().zip(&delta).map(|(&w, &d)| w * d).sum::<f32>()
                } else {
                    0.0
                }
            })
            .collect();

        let lr = self.learning_rate;
        for (row, &a) in self.hidden_weights.iter_mut().zip(&hidden) {
            for (w, &d) in row.iter_mut().zip(&delta) {
                *w -= lr * d * a;
            }
        }
        for (b, &d) in self.hidden_bias.iter_mut().zip(&delta) {
            *b -= lr * d;
        }
        for (row, &x) in self.input_weights.iter_mut().zip(&inputs) {
            for (w, &g) in row.iter_mut().zip(&hidden_grad) {
                *w -= lr * g * x;
            }
        }
        for (b, &g) in self.input_bias.iter_mut().zip(&hidden_grad) {
            *b -= lr * g;
        }
        self.training_steps += 1;
    }

    /// Widens the observed address range to include `addr`.
    fn update_range(&mut self, addr: u64) {
        if addr < self.min_address {
            self.min_address = addr;
        }
        if addr > self.max_address {
            self.max_address = addr;
        }
    }

    /// Selects the most probable buckets whose probability is at least
    /// `confidence_threshold`, converts them to addresses and returns at most
    /// `max_predictions` distinct `(address, probability)` pairs, most likely first.
    fn get_top_predictions(&self, probs: &[f32]) -> Vec<(u64, f32)> {
        let mut ranked: Vec<(usize, f32)> = probs
            .iter()
            .copied()
            .enumerate()
            .filter(|&(_, prob)| prob >= self.confidence_threshold)
            .collect();
        // `total_cmp` gives a total order, so this can never panic.
        ranked.sort_by(|a, b| b.1.total_cmp(&a.1));

        let mut out: Vec<(u64, f32)> = Vec::with_capacity(self.max_predictions.min(ranked.len()));
        for (idx, prob) in ranked {
            if out.len() >= self.max_predictions {
                break;
            }
            // Several buckets can map to the same address when the observed range is
            // narrower than `vocab_size`; report each address only once.
            let addr = self.vocab_to_address(idx);
            if !out.iter().any(|&(seen, _)| seen == addr) {
                out.push((addr, prob));
            }
        }
        out
    }
}

impl<K> PrefetchStrategy<K> for NeuralPrefetch<K>
where
    K: Copy + KeyConvert,
{
    /// Predicts keys likely to be accessed after `key`.
    ///
    /// Until `sequence_length` accesses have been recorded, this returns the next three
    /// sequential keys. After that, it runs the network over the recorded history and
    /// returns the addresses of buckets whose probability meets `confidence_threshold`.
    fn predict_next(&mut self, key: &K) -> Vec<K> {
        let addr = key.to_u64();
        if self.access_history.len() < self.sequence_length {
            // Wrapping keeps the fallback panic-free for addresses near `u64::MAX`, which
            // is where negative signed keys land after `to_u64` sign-extension; the
            // wrapped value converts back to the correct successor key.
            return (1..=3).map(|i| K::from_u64(addr.wrapping_add(i))).collect();
        }
        let input_seq: Vec<usize> = self
            .access_history
            .iter()
            .map(|&a| self.address_to_vocab(a))
            .collect();
        let probs = self.forward(&input_seq);
        self.get_top_predictions(&probs)
            .into_iter()
            .map(|(predicted, _)| K::from_u64(predicted))
            .collect()
    }

    /// Records an access to `key`. When a full window of history is available, it first
    /// trains the network to predict this access from that window.
    fn update_access_pattern(&mut self, key: &K) {
        let addr = key.to_u64();
        self.update_range(addr);
        if self.access_history.len() >= self.sequence_length {
            // The history never exceeds `sequence_length` entries, so this is exactly the
            // window that preceded `addr`. The whole window must be used because
            // `backward_update` ignores inputs whose length differs from `sequence_length`.
            let input_seq: Vec<usize> = self
                .access_history
                .iter()
                .map(|&a| self.address_to_vocab(a))
                .collect();
            let target = self.address_to_vocab(addr);
            self.backward_update(&input_seq, target);
        }
        self.access_history.push_back(addr);
        while self.access_history.len() > self.sequence_length {
            self.access_history.pop_front();
        }
    }

    /// Forgets all history and the observed address range, and re-initialises the network.
    fn reset(&mut self) {
        self.access_history.clear();
        self.min_address = u64::MAX;
        self.max_address = 0;
        self.training_steps = 0;
        self.input_weights = Self::init_weights(self.sequence_length, self.hidden_size);
        self.hidden_weights = Self::init_weights(self.hidden_size, self.vocab_size);
        self.input_bias.fill(0.0);
        self.hidden_bias.fill(0.0);
    }
}

impl<K> Default for NeuralPrefetch<K>
where
    K: Copy + KeyConvert,
{
    fn default() -> Self {
        Self::new()
    }
}