use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
use crate::layout::BlobGuid;
/// Number of independently locked shards in a `RoutingCache`.
///
/// Must be a power of two: shard selection masks the GUID hash with
/// `SHARDS - 1`, which only spreads keys over every shard when this holds.
const SHARDS: usize = 16;
const _: () = assert!(SHARDS.is_power_of_two(), "SHARDS must be a power of two");
/// One cached region, tagged with the compaction generation it was stored under.
struct Entry {
    /// Owned copy of the cached bytes.
    region: Box<[u8]>,
    /// Generation tag; a lookup is only served when this equals the caller's value.
    compact_times: u32,
}
/// Mutable state of a single cache shard; always accessed under its mutex.
#[derive(Default)]
struct Shard {
    /// Sum of the `region` lengths of every entry currently in `map`.
    bytes: usize,
    /// Cached entries, keyed by blob GUID.
    map: HashMap<BlobGuid, Entry>,
}
/// Sharded cache of byte regions keyed by `BlobGuid`, with a per-shard byte budget.
///
/// Every shard sits behind its own `Mutex`, so operations on GUIDs that map to
/// different shards never wait on the same lock.
pub(crate) struct RoutingCache {
    /// Count of lookups served from the cache.
    hits: AtomicU64,
    /// Count of lookups that were not served from the cache.
    misses: AtomicU64,
    /// Maximum number of region bytes each shard is meant to hold.
    shard_budget_bytes: usize,
    /// Per-shard state (`new` allocates `SHARDS` of them).
    shards: Box<[Mutex<Shard>]>,
}
impl RoutingCache {
    /// Creates an empty cache whose budget is split evenly over `SHARDS` shards.
    ///
    /// Each shard receives `total_budget_bytes / SHARDS` bytes. That share is
    /// never allowed below 64 KiB, so for small requests the effective total
    /// budget can exceed `total_budget_bytes`.
    pub(crate) fn new(total_budget_bytes: usize) -> Self {
        let shard_budget_bytes = (total_budget_bytes / SHARDS).max(64 * 1024);
        let shards = (0..SHARDS)
            .map(|_| Mutex::new(Shard::default()))
            .collect::<Box<[_]>>();
        Self {
            shards,
            shard_budget_bytes,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Returns the shard responsible for `guid`.
    ///
    /// Bytes 8..16 of the GUID are read as a little-endian `u64` and masked
    /// with `SHARDS - 1`. Only the low bits survive the mask, so truncating
    /// the value to `usize` on narrower targets does not matter.
    #[inline]
    fn shard(&self, guid: &BlobGuid) -> &Mutex<Shard> {
        let h = u64::from_le_bytes(
            guid[8..16]
                .try_into()
                .expect("a 16-byte GUID always has an 8-byte tail"),
        );
        &self.shards[(h as usize) & (SHARDS - 1)]
    }

    /// Copies the cached region for `guid` into `dst` if it is still valid.
    ///
    /// A hit requires the stored entry to carry the same `compact_times` as
    /// the caller and a region of exactly `dst.len()` bytes. An entry that is
    /// present but fails either check is treated as stale and evicted.
    ///
    /// Returns `true` on a hit, in which case `dst` is overwritten. Returns
    /// `false` otherwise, leaving `dst` untouched. Every call bumps exactly
    /// one of the hit/miss counters.
    pub(crate) fn fill(&self, guid: BlobGuid, compact_times: u32, dst: &mut [u8]) -> bool {
        let mut shard = self.shard(&guid).lock().unwrap();
        let hit = match shard.map.get(&guid) {
            Some(e) if e.compact_times == compact_times && e.region.len() == dst.len() => {
                dst.copy_from_slice(&e.region);
                true
            }
            Some(_) => {
                // Present but from another generation (or of a different size):
                // drop it so it can never be served later.
                if let Some(stale) = shard.map.remove(&guid) {
                    shard.bytes -= stale.region.len();
                }
                false
            }
            None => false,
        };
        let counter = if hit { &self.hits } else { &self.misses };
        counter.fetch_add(1, Ordering::Relaxed);
        hit
    }

    /// Stores `region` for `guid` under generation `compact_times`.
    ///
    /// Any previously cached entry for `guid` is always dropped, since the new
    /// generation supersedes it. If `region` alone exceeds the per-shard
    /// budget it is not cached at all; storing it would leave the shard over
    /// budget. Otherwise, if adding it would push the shard past its budget,
    /// every entry in the shard is evicted first (wholesale, not LRU).
    pub(crate) fn put(&self, guid: BlobGuid, compact_times: u32, region: &[u8]) {
        let mut shard = self.shard(&guid).lock().unwrap();
        if let Some(old) = shard.map.remove(&guid) {
            shard.bytes -= old.region.len();
        }
        if region.len() > self.shard_budget_bytes {
            // Can never fit; caching it would violate `bytes <= shard_budget_bytes`.
            return;
        }
        if shard.bytes + region.len() > self.shard_budget_bytes {
            shard.map.clear();
            shard.bytes = 0;
        }
        shard.bytes += region.len();
        shard.map.insert(
            guid,
            Entry {
                compact_times,
                region: region.into(),
            },
        );
    }

    /// Number of `fill` calls that were served from the cache.
    #[cfg(test)]
    pub(crate) fn hits(&self) -> u64 {
        self.hits.load(Ordering::Relaxed)
    }

    /// Number of `fill` calls that were not served from the cache.
    #[cfg(test)]
    pub(crate) fn misses(&self) -> u64 {
        self.misses.load(Ordering::Relaxed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a GUID whose first byte and shard-selecting byte (index 8) are `n`.
    fn guid(n: u8) -> BlobGuid {
        let mut g = [0u8; 16];
        g[0] = n;
        g[8] = n;
        g
    }

    #[test]
    fn hit_only_on_matching_compact_times() {
        let c = RoutingCache::new(1 << 20);
        let g = guid(1);
        c.put(g, 5, &[0xAB; 64]);
        let mut dst = [0u8; 64];
        assert!(c.fill(g, 5, &mut dst), "matching compact_times hits");
        assert_eq!(dst, [0xAB; 64]);
        assert_eq!(c.hits(), 1);
        assert!(!c.fill(g, 6, &mut [0u8; 64]), "stale compact_times misses");
        assert!(!c.fill(g, 5, &mut [0u8; 64]), "stale entry evicted");
        assert_eq!(c.misses(), 2);
    }

    #[test]
    fn put_refreshes_to_new_generation() {
        let c = RoutingCache::new(1 << 20);
        let g = guid(2);
        c.put(g, 1, &[1u8; 32]);
        c.put(g, 2, &[2u8; 48]);
        let mut dst = [0u8; 48];
        assert!(c.fill(g, 2, &mut dst));
        assert_eq!(dst, [2u8; 48]);
        assert!(!c.fill(g, 1, &mut [0u8; 32]), "old generation gone");
    }

    #[test]
    fn stays_bounded_under_overflow() {
        // Minimum per-shard budget (64 KiB) with 8 KiB regions spread over all
        // shards, so every shard overflows its budget many times.
        let c = RoutingCache::new(SHARDS * 64 * 1024);
        let region = vec![0u8; 8192];
        for n in 0..2000u32 {
            let mut g = [0u8; 16];
            g[8..12].copy_from_slice(&n.to_le_bytes());
            c.put(g, 1, &region);
        }
        for s in c.shards.iter() {
            let s = s.lock().unwrap();
            assert!(
                s.bytes <= c.shard_budget_bytes,
                "shard over budget: {} > {}",
                s.bytes,
                c.shard_budget_bytes
            );
        }
    }
}