use std::hash::Hash;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::atomic::AtomicF32;
use crate::estimator::Estimator;
/// Combined frequency / reuse-distance sketch.
///
/// Pairs two estimators with a shared access counter that `tick` advances and
/// periodically resets.
pub(crate) struct UnoSketch {
    /// Configured cache size; `tick` decays `freq_estimator` every
    /// `2 * cache_size` ticks.
    cache_size: usize,
    /// Counter value at which `tick` resets `window_counter` and ages
    /// `reuse_estimator` (the constructors set it to `cache_size * 8`).
    window_limit: usize,
    /// Number of ticks seen in the current window.
    window_counter: AtomicUsize,
    /// Estimator queried/incremented for access frequency.
    freq_estimator: Estimator,
    /// Estimator queried/incremented for reuse counts, which
    /// `reuse_estimate` turns into a distance.
    reuse_estimator: Estimator,
}
impl UnoSketch {
pub fn get<T: Hash + Clone>(&self, key: T) -> (u8, usize) {
let freq = self.freq_estimator.get(key.clone());
let reuse_freq = self.reuse_estimator.get(key);
let reuse = self.reuse_estimate(reuse_freq);
(freq, reuse)
}
pub fn tick(&self) {
let window_size = self.window_counter.fetch_add(1, Ordering::SeqCst);
if window_size % (self.cache_size * 2) == 0 {
self.freq_estimator.exponential_decay(0.8);
}
if window_size == self.window_limit || window_size > self.window_limit * 2 {
self.window_counter.store(0, Ordering::SeqCst);
self.reuse_estimator.age(1); }
}
pub fn incr<T: Hash + Clone>(&self, key: T) -> (u8, usize) {
self.tick();
let freq = self.freq_estimator.incr(key.clone());
let reuse_freq = self.reuse_estimator.incr(key);
let reuse = self.reuse_estimate(reuse_freq);
(freq, reuse)
}
pub fn incr_freq_only<T: Hash + Clone>(&self, key: T) -> (u8, usize) {
let freq = self.freq_estimator.incr(key.clone());
let reuse_freq = self.reuse_estimator.get(key);
let reuse = self.reuse_estimate(reuse_freq);
(freq, reuse)
}
pub fn incr_reuse_only<T: Hash + Clone>(&self, key: T) -> (u8, usize) {
let freq = self.freq_estimator.get(key.clone());
let reuse_freq = self.reuse_estimator.incr(key);
let reuse = self.reuse_estimate(reuse_freq);
(freq, reuse)
}
pub fn new(cache_size: usize) -> Self {
Self {
freq_estimator: Estimator::optimal(cache_size),
reuse_estimator: Estimator::optimal(cache_size),
window_counter: Default::default(),
window_limit: cache_size * 8,
cache_size,
}
}
pub fn new_compact(cache_size: usize) -> Self {
Self {
freq_estimator: Estimator::compact(cache_size),
reuse_estimator: Estimator::compact(cache_size),
window_counter: Default::default(),
window_limit: cache_size * 8,
cache_size,
}
}
fn reuse_estimate(&self, freq: u8) -> usize {
let window_value = self.window_counter.load(Ordering::Relaxed);
if freq > 0 {
((self.window_limit + window_value) as f32) as usize / freq as usize
} else {
self.window_limit + window_value
}
}
}
/// Online linear model over `(frequency, distance)` features.
///
/// `predict` evaluates `bias + weight_freq * freq + weight_distance * scaled_distance`.
/// `weight_update` moves all three parameters by `learning_rate`.
pub struct UnoLearner {
    /// Distance scaling base: distances are divided by `scale_factor + 1`
    /// (set from `cache_size` in `new`).
    pub scale_factor: usize,
    /// Step size added to (hit) or subtracted from (miss) the parameters.
    learning_rate: f32,
    /// Coefficient applied to the scaled distance feature.
    weight_distance: AtomicF32,
    /// Coefficient applied to the frequency feature.
    weight_freq: AtomicF32,
    /// Constant term of the model.
    bias: AtomicF32,
}
impl UnoLearner {
    /// Creates a learner with initial weights `weight_distance = -0.2`,
    /// `weight_freq = 0.8` and `bias = 0.5`, scaling distances by `cache_size`.
    pub fn new(learning_rate: f32, cache_size: usize) -> UnoLearner {
        Self {
            weight_distance: AtomicF32::new(-0.2),
            weight_freq: AtomicF32::new(0.8),
            bias: AtomicF32::new(0.5),
            learning_rate,
            scale_factor: cache_size,
        }
    }

    /// Moves every parameter by one learning-rate step.
    ///
    /// The step is positive when `touch` is true and negative otherwise.
    /// Each weight moves proportionally to its feature; the bias moves by the
    /// bare step. Loads and stores are separate `Relaxed` operations.
    pub fn weight_update(&self, freq: u8, distance: usize, touch: bool) {
        let step = match touch {
            true => self.learning_rate,
            false => -self.learning_rate,
        };
        let scaled = self.scaled_distance(distance);

        let next_distance = self.weight_distance.load(Ordering::Relaxed) + step * scaled;
        let next_freq = self.weight_freq.load(Ordering::Relaxed) + step * freq as f32;
        let next_bias = self.bias.load(Ordering::Relaxed) + step;

        self.weight_distance.store(next_distance, Ordering::Relaxed);
        self.weight_freq.store(next_freq, Ordering::Relaxed);
        self.bias.store(next_bias, Ordering::Relaxed);
    }

    /// Scores an entry from its frequency and (scaled) distance using the
    /// current parameters.
    pub fn predict(&self, freq: u8, distance: usize) -> f32 {
        let b = self.bias.load(Ordering::SeqCst);
        let w_freq = self.weight_freq.load(Ordering::SeqCst);
        let w_dist = self.weight_distance.load(Ordering::SeqCst);
        b + w_freq * freq as f32 + w_dist * self.scaled_distance(distance)
    }

    /// Integer-divides `distance` by `scale_factor + 1` and returns the
    /// quotient as `f32`, so the fractional part is dropped.
    /// The `+ 1` keeps the divisor non-zero when `scale_factor` is 0.
    fn scaled_distance(&self, distance: usize) -> f32 {
        (distance / (self.scale_factor + 1)) as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Which sketch method a scripted step invokes.
    #[derive(Clone, Copy, Debug)]
    enum Op {
        Get,
        Incr,
    }

    /// Replays a fixed script of `get`/`incr` calls on a sketch built with
    /// `cache_size = 1` and checks every returned `(freq, reuse)` pair in order.
    #[test]
    fn test_uno() {
        let script: [(Op, i32, (u8, usize)); 16] = [
            (Op::Get, 1, (0, 8)),
            (Op::Incr, 1, (1, 9)),
            (Op::Incr, 1, (2, 5)),
            (Op::Get, 1, (2, 5)),
            (Op::Get, 2, (0, 10)),
            (Op::Incr, 2, (1, 11)),
            (Op::Incr, 2, (2, 6)),
            (Op::Get, 2, (2, 6)),
            (Op::Incr, 3, (1, 13)),
            (Op::Incr, 3, (2, 7)),
            (Op::Incr, 3, (3, 5)),
            (Op::Incr, 3, (4, 4)),
            (Op::Incr, 3, (4, 2)),
            (Op::Incr, 1, (3, 4)),
            (Op::Incr, 2, (3, 5)),
            (Op::Get, 3, (3, 3)),
        ];

        let uno = UnoSketch::new(1);
        for (step, &(op, key, expected)) in script.iter().enumerate() {
            let actual = match op {
                Op::Get => uno.get(key),
                Op::Incr => uno.incr(key),
            };
            assert_eq!(actual, expected, "step {} ({:?} key {})", step, op, key);
        }
    }
}