/// Element type stored in every cached row.
pub type Qfloat = f32;
/// Link value meaning "this LRU node is not currently part of the list".
const NONE: usize = !0;
/// One node of the index-based, doubly-linked LRU list kept by `Cache`.
///
/// Links are indices into `Cache::nodes`; a node whose links are both `NONE`
/// is not in the list.
struct LruNode {
    /// Index of the preceding node (towards the least recently used end).
    prev: usize,
    /// Index of the following node (towards the most recently used end).
    next: usize,
}
/// Fixed-budget cache of per-index rows of `Qfloat`, evicted in LRU order.
pub struct Cache {
    /// Number of row slots; also the position of the list sentinel in `nodes`.
    l: usize,
    /// Remaining free budget, counted in `Qfloat` elements.
    size: usize,
    /// Storage of each row; `None` once a row has been evicted (or never loaded).
    data: Vec<Option<Vec<Qfloat>>>,
    /// Number of allocated entries currently held for each row.
    len: Vec<usize>,
    /// `l` per-row LRU nodes followed by one sentinel node at index `l`.
    nodes: Vec<LruNode>,
}
impl Cache {
pub fn new(l: usize, size_bytes: usize) -> Self {
let mut size = size_bytes / std::mem::size_of::<Qfloat>();
let header_size = l * std::mem::size_of::<LruNode>() / std::mem::size_of::<Qfloat>();
size = size.max(2 * l + header_size).saturating_sub(header_size);
let mut nodes: Vec<LruNode> = (0..l)
.map(|_| LruNode {
prev: NONE,
next: NONE,
})
.collect();
nodes.push(LruNode { prev: l, next: l });
Cache {
l,
size,
data: (0..l).map(|_| None).collect(),
len: vec![0; l],
nodes,
}
}
#[inline]
fn lru_delete(&mut self, i: usize) {
let prev = self.nodes[i].prev;
let next = self.nodes[i].next;
self.nodes[prev].next = next;
self.nodes[next].prev = prev;
self.nodes[i].prev = NONE;
self.nodes[i].next = NONE;
}
#[inline]
fn lru_insert(&mut self, i: usize) {
let head = self.l; let tail = self.nodes[head].prev;
self.nodes[i].next = head;
self.nodes[i].prev = tail;
self.nodes[tail].next = i;
self.nodes[head].prev = i;
}
#[inline]
fn in_lru(&self, i: usize) -> bool {
self.nodes[i].prev != NONE
}
pub fn get_data(&mut self, index: usize, request_len: usize) -> (&mut [Qfloat], usize) {
assert!(index < self.l);
if self.in_lru(index) {
self.lru_delete(index);
}
let old_len = self.len[index];
let more = request_len.saturating_sub(old_len);
if more > 0 {
let head = self.l;
while self.size < more {
let victim = self.nodes[head].next;
if victim == head {
break; }
self.lru_delete(victim);
self.size += self.len[victim];
self.data[victim] = None;
self.len[victim] = 0;
}
let entry = self.data[index].get_or_insert_with(Vec::new);
entry.resize(request_len, 0.0);
self.size -= more;
self.len[index] = request_len;
}
self.lru_insert(index);
let start = old_len;
let data = self.data[index].get_or_insert_with(Vec::new);
(data.as_mut_slice(), start)
}
pub fn swap_index(&mut self, i: usize, j: usize) {
if i == j {
return;
}
let i_in = self.in_lru(i);
let j_in = self.in_lru(j);
if i_in {
self.lru_delete(i);
}
if j_in {
self.lru_delete(j);
}
self.data.swap(i, j);
self.len.swap(i, j);
if i_in {
self.lru_insert(j);
}
if j_in {
self.lru_insert(i);
}
let (lo, hi) = if i < j { (i, j) } else { (j, i) };
let head = self.l;
let mut h = self.nodes[head].next;
while h != head {
let next = self.nodes[h].next;
if self.len[h] > lo {
if self.len[h] > hi {
if let Some(ref mut row) = self.data[h] {
row.swap(lo, hi);
}
} else {
self.lru_delete(h);
self.size += self.len[h];
self.data[h] = None;
self.len[h] = 0;
}
}
h = next;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Qfloat` slots charged for LRU bookkeeping, computed the same way as in `Cache::new`.
    fn header_floats(l: usize) -> usize {
        l * std::mem::size_of::<LruNode>() / std::mem::size_of::<Qfloat>()
    }

    #[test]
    fn basic_get_and_fill() {
        let mut cache = Cache::new(3, 100);

        let (row, first_missing) = cache.get_data(0, 3);
        assert_eq!(first_missing, 0);
        assert_eq!(row.len(), 3);
        row.copy_from_slice(&[1.0, 2.0, 3.0]);

        let (row, first_missing) = cache.get_data(0, 3);
        assert_eq!(first_missing, 3);
        assert_eq!(&row[..3], &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn extend_cached_row() {
        let mut cache = Cache::new(3, 1000);

        let (row, first_missing) = cache.get_data(0, 2);
        assert_eq!(first_missing, 0);
        row.copy_from_slice(&[10.0, 20.0]);

        // Growing the row keeps the previously written prefix.
        let (row, first_missing) = cache.get_data(0, 3);
        assert_eq!(first_missing, 2);
        assert_eq!(&row[..2], &[10.0, 20.0]);
        row[2] = 30.0;
    }

    #[test]
    fn lru_eviction() {
        let l = 10;
        // Tight budget: filling rows 0, 1 and 2 pushes row 0 out.
        let bytes = (2 * l + l * 3) * std::mem::size_of::<Qfloat>();
        let mut cache = Cache::new(l, bytes);
        for (idx, marker) in [(0, 1.0), (1, 3.0), (2, 5.0)] {
            let (row, first_missing) = cache.get_data(idx, l);
            assert_eq!(first_missing, 0);
            row[0] = marker;
        }
        let (_, first_missing) = cache.get_data(0, l);
        assert_eq!(first_missing, 0);
    }

    #[test]
    fn lru_order_respects_access() {
        let l = 5;
        let row_len = l;
        let budget = 3 * row_len + header_floats(l);
        let mut cache = Cache::new(l, budget * std::mem::size_of::<Qfloat>());
        for (idx, marker) in [(0, 10.0), (1, 20.0), (2, 30.0)] {
            let (row, _) = cache.get_data(idx, row_len);
            row[0] = marker;
        }

        // Touching row 0 leaves row 1 as the least recently used.
        let (row, first_missing) = cache.get_data(0, row_len);
        assert_eq!(first_missing, row_len);
        assert_eq!(row[0], 10.0);

        // Loading row 3 therefore evicts row 1 and nothing else.
        let (row, first_missing) = cache.get_data(3, row_len);
        assert_eq!(first_missing, 0);
        row[0] = 40.0;
        assert!(cache.data[1].is_none());
        assert!(cache.data[0].is_some());
        assert!(cache.data[2].is_some());
    }

    #[test]
    fn swap_index_row_swap() {
        let mut cache = Cache::new(3, 1000);
        let (row, _) = cache.get_data(0, 3);
        row.copy_from_slice(&[10.0, 20.0, 30.0]);

        cache.swap_index(0, 2);

        // The row moved to index 2 and its own columns 0 and 2 were exchanged.
        let (row, first_missing) = cache.get_data(2, 3);
        assert_eq!(first_missing, 3);
        assert_eq!(&row[..3], &[30.0, 20.0, 10.0]);

        // Index 0 received the empty slot.
        let (_, first_missing) = cache.get_data(0, 3);
        assert_eq!(first_missing, 0);
    }

    #[test]
    fn swap_index_swaps_columns_in_other_rows() {
        let mut cache = Cache::new(4, 10000);
        let (row, _) = cache.get_data(0, 4);
        row.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let (row, _) = cache.get_data(1, 4);
        row.copy_from_slice(&[10.0, 20.0, 30.0, 40.0]);

        cache.swap_index(1, 3);

        let (row, first_missing) = cache.get_data(0, 4);
        assert_eq!(first_missing, 4);
        assert_eq!(&row[..4], &[1.0, 4.0, 3.0, 2.0]);

        let (row, first_missing) = cache.get_data(3, 4);
        assert_eq!(first_missing, 4);
        assert_eq!(&row[..4], &[10.0, 40.0, 30.0, 20.0]);
    }

    #[test]
    fn swap_index_evicts_partial_rows() {
        let mut cache = Cache::new(4, 10000);
        let (row, _) = cache.get_data(0, 4);
        row.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let (row, _) = cache.get_data(1, 2);
        row.copy_from_slice(&[10.0, 20.0]);

        cache.swap_index(1, 3);

        // The full-length row gets columns 1 and 3 exchanged.
        let (row, first_missing) = cache.get_data(0, 4);
        assert_eq!(first_missing, 4);
        assert_eq!(row[1], 4.0);
        assert_eq!(row[3], 2.0);

        // Index 1 took over the empty slot from index 3.
        let (_, first_missing) = cache.get_data(1, 2);
        assert_eq!(first_missing, 0);
    }
}