// breve 0.1.0
// In-memory cache implementation with Uno as the admission policy and S3-FIFO as the eviction policy.
//
// Copyright 2025 Chojan Shang.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::hash::Hash;
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::atomic::AtomicF32;
use crate::estimator::Estimator;

/// Bare-minimum UnoSketch, based on Count-Min Sketch, to estimate the frequency and
/// the reuse distance of keys within a limited window.
pub(crate) struct UnoSketch {
    /// Estimator holding the per-key frequency counters; `tick` decays it by a factor of 0.8.
    freq_estimator: Estimator,
    /// Estimator holding the per-key counters for the current window; `tick` ages it
    /// (right shift by one) whenever the window resets.
    reuse_estimator: Estimator,
    /// Number of `tick` calls since the last window reset.
    window_counter: AtomicUsize,
    /// Cache capacity given to the constructor; the frequency decay runs whenever the
    /// window counter is a multiple of `2 * cache_size`.
    cache_size: usize,
    /// Window counter value at which the window resets; the constructors set it to
    /// `8 * cache_size`.
    window_limit: usize,
}

impl UnoSketch {
    pub fn get<T: Hash + Clone>(&self, key: T) -> (u8, usize) {
        let freq = self.freq_estimator.get(key.clone());
        let reuse_freq = self.reuse_estimator.get(key);
        let reuse = self.reuse_estimate(reuse_freq);
        (freq, reuse)
    }

    pub fn tick(&self) {
        let window_size = self.window_counter.fetch_add(1, Ordering::SeqCst);
        if window_size % (self.cache_size * 2) == 0 {
            self.freq_estimator.exponential_decay(0.8);
        }
        // When window_size concurrently increases, only one resets the window and age the estimator.
        // > self.window_limit * 2 is a safety net in case for whatever reason window_size grows
        // out of control
        if window_size == self.window_limit || window_size > self.window_limit * 2 {
            self.window_counter.store(0, Ordering::SeqCst);
            self.reuse_estimator.age(1); // right shift, next window
        }
    }

    pub fn incr<T: Hash + Clone>(&self, key: T) -> (u8, usize) {
        self.tick();
        let freq = self.freq_estimator.incr(key.clone());
        let reuse_freq = self.reuse_estimator.incr(key);
        let reuse = self.reuse_estimate(reuse_freq);
        (freq, reuse)
    }

    pub fn incr_freq_only<T: Hash + Clone>(&self, key: T) -> (u8, usize) {
        let freq = self.freq_estimator.incr(key.clone());
        let reuse_freq = self.reuse_estimator.get(key);
        let reuse = self.reuse_estimate(reuse_freq);
        (freq, reuse)
    }

    pub fn incr_reuse_only<T: Hash + Clone>(&self, key: T) -> (u8, usize) {
        let freq = self.freq_estimator.get(key.clone());
        let reuse_freq = self.reuse_estimator.incr(key);
        let reuse = self.reuse_estimate(reuse_freq);
        (freq, reuse)
    }

    // because we use 8-bits counters, window size can be 256 * the cache size
    pub fn new(cache_size: usize) -> Self {
        Self {
            freq_estimator: Estimator::optimal(cache_size),
            reuse_estimator: Estimator::optimal(cache_size),
            window_counter: Default::default(),
            // 8x: just a heuristic to balance the memory usage and accuracy
            window_limit: cache_size * 8,
            cache_size,
        }
    }

    pub fn new_compact(cache_size: usize) -> Self {
        Self {
            freq_estimator: Estimator::compact(cache_size),
            reuse_estimator: Estimator::compact(cache_size),
            window_counter: Default::default(),
            // 8x: just a heuristic to balance the memory usage and accuracy
            window_limit: cache_size * 8,
            cache_size,
        }
    }

    fn reuse_estimate(&self, freq: u8) -> usize {
        let window_value = self.window_counter.load(Ordering::Relaxed);
        if freq > 0 {
            ((self.window_limit + window_value) as f32) as usize / freq as usize
        } else {
            self.window_limit + window_value
        }
    }
}

/// Linear model over `(frequency, scaled distance)`: `predict` returns
/// `bias + weight_freq * freq + weight_distance * scaled_distance`, and
/// `weight_update` shifts the three parameters by multiples of `learning_rate`.
pub struct UnoLearner {
    /// Coefficient applied to the scaled distance; initialised to -0.2 by `new`.
    weight_distance: AtomicF32,
    /// Coefficient applied to the frequency; initialised to 0.8 by `new`.
    weight_freq: AtomicF32,
    /// Constant term of the prediction; initialised to 0.5 by `new`.
    bias: AtomicF32,
    /// Magnitude of a single update step taken by `weight_update`.
    learning_rate: f32,
    /// Distances are integer-divided by `scale_factor + 1` before being used;
    /// `new` sets this to the cache size.
    pub scale_factor: usize,
}

impl UnoLearner {
    /// Builds a learner with the starting parameters `weight_distance = -0.2`,
    /// `weight_freq = 0.8` and `bias = 0.5`. `cache_size` becomes the
    /// `scale_factor` used to scale distances.
    pub fn new(learning_rate: f32, cache_size: usize) -> UnoLearner {
        let weight_distance = AtomicF32::new(-0.2);
        let weight_freq = AtomicF32::new(0.8);
        let bias = AtomicF32::new(0.5);
        UnoLearner {
            weight_distance,
            weight_freq,
            bias,
            learning_rate,
            scale_factor: cache_size,
        }
    }

    /// Shifts all three parameters by one learning step for the sample
    /// `(freq, distance)`.
    ///
    /// The step is `+learning_rate` when `touch` is `true` and `-learning_rate`
    /// otherwise. The distance weight moves by `step * scaled distance`, the
    /// frequency weight by `step * freq` and the bias by `step`.
    pub fn weight_update(&self, freq: u8, distance: usize, touch: bool) {
        let step = match touch {
            true => self.learning_rate,
            false => -self.learning_rate,
        };
        let scaled = self.scaled_distance(distance);

        // Read every parameter first, then write them back; each access is an
        // independent relaxed operation, not a single atomic read-modify-write.
        let next_distance = self.weight_distance.load(Ordering::Relaxed) + step * scaled;
        let next_freq = self.weight_freq.load(Ordering::Relaxed) + step * f32::from(freq);
        let next_bias = self.bias.load(Ordering::Relaxed) + step;

        self.weight_distance.store(next_distance, Ordering::Relaxed);
        self.weight_freq.store(next_freq, Ordering::Relaxed);
        self.bias.store(next_bias, Ordering::Relaxed);
    }

    /// Evaluates the linear model for `(freq, distance)`:
    /// `bias + weight_freq * freq + weight_distance * scaled distance`.
    pub fn predict(&self, freq: u8, distance: usize) -> f32 {
        let bias = self.bias.load(Ordering::SeqCst);
        let freq_weight = self.weight_freq.load(Ordering::SeqCst);
        let distance_weight = self.weight_distance.load(Ordering::SeqCst);

        let freq_term = freq_weight * f32::from(freq);
        let distance_term = distance_weight * self.scaled_distance(distance);
        bias + freq_term + distance_term
    }

    /// Scales a raw distance by integer-dividing it by `scale_factor + 1`
    /// (the `+ 1` keeps the divisor non-zero for a zero scale factor).
    // NOTE(review): the division truncates before the float conversion, so any
    // distance below `scale_factor + 1` scales to 0.0 — confirm this is intended.
    fn scaled_distance(&self, distance: usize) -> f32 {
        (distance / (self.scale_factor + 1)) as f32
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a sketch built for a cache of size 1 (so `window_limit` is 8) through one
    /// full window and the reset that follows.
    ///
    /// Each tuple is `(frequency, reuse distance)`, where the reuse distance is
    /// `(window_limit + window_counter) / reuse counter`, or the undivided sum when the
    /// reuse counter is 0. The assertions are order-dependent: every `incr` advances
    /// the window counter by one.
    #[test]
    fn test_uno() {
        let uno = UnoSketch::new(1);
        // Unseen key, window counter still 0: 8 + 0.
        assert_eq!(uno.get(1), (0, 8));
        // Window counter 1: (8 + 1) / 1.
        assert_eq!(uno.incr(1), (1, 9));
        // Window counter 2: (8 + 2) / 2.
        assert_eq!(uno.incr(1), (2, 5));
        // `get` neither counts an access nor advances the window.
        assert_eq!(uno.get(1), (2, 5));

        assert_eq!(uno.get(2), (0, 10));
        assert_eq!(uno.incr(2), (1, 11));
        assert_eq!(uno.incr(2), (2, 6));
        assert_eq!(uno.get(2), (2, 6));

        assert_eq!(uno.incr(3), (1, 13));
        assert_eq!(uno.incr(3), (2, 7));
        assert_eq!(uno.incr(3), (3, 5));
        assert_eq!(uno.incr(3), (4, 4));

        // 8 incr(), now reset
        // The next tick observes window counter == window_limit, so it stores 0 into the
        // counter and ages the reuse estimator before the key is counted.

        assert_eq!(uno.incr(3), (4, 2));
        assert_eq!(uno.incr(1), (3, 4));
        assert_eq!(uno.incr(2), (3, 5));
        // Window counter is 2 after the two increments above: (8 + 2) / 3.
        assert_eq!(uno.get(3), (3, 3));
    }
}