use std::cell::Cell;
use std::collections::{HashMap, VecDeque};
/// Full identity of one cached document.
///
/// Two keys are equal only when database, tenant, collection and document id
/// all match, so identical document ids in different tenants or databases
/// never collide.
#[derive(Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Database the document lives in.
    database_id: u64,
    /// Tenant owning the document inside that database.
    tenant_id: u64,
    /// Collection name within the tenant.
    collection: String,
    /// Document identifier within the collection.
    document_id: String,
}
/// One database's partition of the cache.
///
/// Entries leave in insertion order: `order` is a FIFO of keys whose front is
/// the next eviction candidate. Keys in `order` whose entry is no longer in
/// `entries` are skipped when encountered by `evict_one`.
struct DatabaseShard {
    /// Relative share of the global capacity; clamped to at least 1 by `new`.
    weight: u32,
    /// Cached document bytes keyed by full document identity.
    entries: HashMap<CacheKey, Vec<u8>>,
    /// Keys in insertion order; the front is evicted first.
    order: VecDeque<CacheKey>,
    /// Resident-entry counter per tenant id (decremented with saturation).
    tenant_counts: HashMap<u64, usize>,
}

impl DatabaseShard {
    /// Creates an empty shard. A `weight` of 0 is treated as 1 so that the
    /// division in `overshoot_score` can never divide by zero.
    fn new(weight: u32) -> Self {
        let weight = weight.max(1);
        Self {
            weight,
            entries: HashMap::new(),
            order: VecDeque::new(),
            tenant_counts: HashMap::new(),
        }
    }

    /// Size of this shard relative to its weight, scaled by `total_weight`
    /// so scores of shards with different weights are comparable:
    /// `len * total_weight / weight`, saturating instead of overflowing.
    fn overshoot_score(&self, total_weight: u64) -> u64 {
        let resident = self.entries.len() as u64;
        let scaled = resident.saturating_mul(total_weight);
        scaled.saturating_div(u64::from(self.weight))
    }

    /// Evicts the oldest entry that is still resident.
    ///
    /// Returns `true` when an entry was removed (and its tenant counter
    /// decremented), `false` when the order queue is exhausted first.
    fn evict_one(&mut self) -> bool {
        loop {
            match self.order.pop_front() {
                None => return false,
                Some(candidate) => {
                    if self.entries.remove(&candidate).is_none() {
                        // Queue entry whose document is already gone; skip it.
                        continue;
                    }
                    if let Some(count) = self.tenant_counts.get_mut(&candidate.tenant_id) {
                        *count = count.saturating_sub(1);
                    }
                    return true;
                }
            }
        }
    }
}
/// Document cache shared by several databases, partitioned into one
/// `DatabaseShard` per database id.
///
/// `capacity` bounds the number of resident entries summed over *all*
/// shards. When a new key is inserted into a full cache, entries are evicted
/// from whichever shard most exceeds its weighted share (see
/// [`DocCache::set_database_weight`]). Within a shard, entries leave in
/// insertion order: neither reads nor overwrites move an entry.
///
/// Hit/miss counters are `Cell`s so that [`DocCache::get`] can update them
/// through `&self`.
pub struct DocCache {
    /// Shards keyed by database id; created lazily and never removed.
    shards: HashMap<u64, DatabaseShard>,
    /// Number of resident entries across all shards.
    total: usize,
    /// Maximum number of resident entries across all shards.
    capacity: usize,
    /// Number of `get` calls that found an entry.
    hits: Cell<u64>,
    /// Number of `get` calls that found nothing.
    misses: Cell<u64>,
}

impl DocCache {
    /// Creates an empty cache holding at most `capacity` entries in total.
    ///
    /// With `capacity == 0` every [`DocCache::put`] is a no-op.
    pub fn new(capacity: usize) -> Self {
        Self {
            shards: HashMap::new(),
            total: 0,
            capacity,
            hits: Cell::new(0),
            misses: Cell::new(0),
        }
    }

    /// Sets the relative capacity weight of `database_id`, creating its empty
    /// shard if it does not exist yet. A weight of 0 is treated as 1.
    ///
    /// A database's fair share is `capacity * weight / total_weight`, where
    /// `total_weight` sums the weights of every database that has a shard.
    /// Databases first seen through [`DocCache::put`] get weight 1.
    pub fn set_database_weight(&mut self, database_id: u64, weight: u32) {
        let weight = weight.max(1);
        self.shards
            .entry(database_id)
            .or_insert_with(|| DatabaseShard::new(weight))
            .weight = weight;
    }

    /// Looks up a document and records the lookup as a hit or a miss.
    ///
    /// Lookups do not change eviction order.
    pub fn get(
        &self,
        database_id: u64,
        tenant_id: u64,
        collection: &str,
        document_id: &str,
    ) -> Option<&[u8]> {
        let key = Self::make_key(database_id, tenant_id, collection, document_id);
        let found = self
            .shards
            .get(&database_id)
            .and_then(|shard| shard.entries.get(&key));
        let counter = if found.is_some() {
            &self.hits
        } else {
            &self.misses
        };
        counter.set(counter.get() + 1);
        found.map(Vec::as_slice)
    }

    /// Inserts or overwrites a document.
    ///
    /// Overwriting an existing key replaces its bytes in place and keeps the
    /// entry's position in the eviction order. Inserting a new key into a full
    /// cache first evicts entries (see `make_room_for`). Databases without a
    /// shard get one with weight 1. With `capacity == 0` nothing is stored.
    pub fn put(
        &mut self,
        database_id: u64,
        tenant_id: u64,
        collection: &str,
        document_id: &str,
        value: &[u8],
    ) {
        // Nothing can ever be resident in a zero-capacity cache. Without this
        // guard the eviction loop finds nothing to evict and the entry would be
        // stored anyway, exceeding capacity.
        if self.capacity == 0 {
            return;
        }
        let key = Self::make_key(database_id, tenant_id, collection, document_id);
        let shard = self
            .shards
            .entry(database_id)
            .or_insert_with(|| DatabaseShard::new(1));
        if let Some(existing) = shard.entries.get_mut(&key) {
            // Reuse the existing buffer rather than allocating a fresh Vec.
            existing.clear();
            existing.extend_from_slice(value);
            return;
        }
        if self.total >= self.capacity {
            self.make_room_for(database_id);
        }
        let shard = self.shards.get_mut(&database_id).expect("just inserted");
        shard.order.push_back(key.clone());
        shard.entries.insert(key, value.to_vec());
        *shard.tenant_counts.entry(tenant_id).or_insert(0) += 1;
        self.total += 1;
    }

    /// Removes a single document if it is resident.
    ///
    /// The key is also removed from its shard's eviction queue, which costs
    /// time linear in the shard's queue length.
    pub fn invalidate(
        &mut self,
        database_id: u64,
        tenant_id: u64,
        collection: &str,
        document_id: &str,
    ) {
        let key = Self::make_key(database_id, tenant_id, collection, document_id);
        let shard = match self.shards.get_mut(&database_id) {
            Some(s) => s,
            None => return,
        };
        if shard.entries.remove(&key).is_none() {
            return;
        }
        // Drop the key from the eviction queue too. Leaving it behind would let
        // a later re-insert of the same key be evicted at this stale, older
        // position, and put/invalidate cycles would grow `order` without bound.
        if let Some(pos) = shard.order.iter().position(|k| *k == key) {
            shard.order.remove(pos);
        }
        if let Some(count) = shard.tenant_counts.get_mut(&tenant_id) {
            *count = count.saturating_sub(1);
        }
        self.total = self.total.saturating_sub(1);
    }

    /// Removes every entry of `collection` owned by `tenant_id` in
    /// `database_id`.
    pub fn evict_collection(&mut self, database_id: u64, tenant_id: u64, collection: &str) {
        let shard = match self.shards.get_mut(&database_id) {
            Some(s) => s,
            None => return,
        };
        let in_collection =
            |k: &CacheKey| k.tenant_id == tenant_id && k.collection == collection;
        let before = shard.entries.len();
        shard.entries.retain(|k, _| !in_collection(k));
        shard.order.retain(|k| !in_collection(k));
        let removed = before.saturating_sub(shard.entries.len());
        if removed > 0 {
            if let Some(count) = shard.tenant_counts.get_mut(&tenant_id) {
                *count = count.saturating_sub(removed);
            }
            self.total = self.total.saturating_sub(removed);
        }
    }

    /// Removes every entry owned by `tenant_id` in `database_id` and drops
    /// that tenant's counter.
    pub fn evict_tenant(&mut self, database_id: u64, tenant_id: u64) {
        let shard = match self.shards.get_mut(&database_id) {
            Some(s) => s,
            None => return,
        };
        let before = shard.entries.len();
        shard.entries.retain(|k, _| k.tenant_id != tenant_id);
        shard.order.retain(|k| k.tenant_id != tenant_id);
        shard.tenant_counts.remove(&tenant_id);
        let removed = before.saturating_sub(shard.entries.len());
        self.total = self.total.saturating_sub(removed);
    }

    /// Fraction of `get` calls that were hits, or `0.0` before any lookup.
    pub fn hit_rate(&self) -> f64 {
        let hits = self.hits.get();
        let total = hits + self.misses.get();
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Number of resident entries across all databases.
    pub fn len(&self) -> usize {
        self.total
    }

    /// Whether no entry is resident.
    pub fn is_empty(&self) -> bool {
        self.total == 0
    }

    /// Total number of `get` calls, hits plus misses.
    pub fn total_lookups(&self) -> u64 {
        self.hits.get() + self.misses.get()
    }

    /// Builds an owned key from borrowed parts.
    fn make_key(database_id: u64, tenant_id: u64, collection: &str, document_id: &str) -> CacheKey {
        CacheKey {
            database_id,
            tenant_id,
            collection: collection.to_string(),
            document_id: document_id.to_string(),
        }
    }

    /// Sum of all shard weights, never less than 1 so it is safe to divide by.
    fn total_weight(&self) -> u64 {
        self.shards
            .values()
            .map(|s| u64::from(s.weight))
            .sum::<u64>()
            .max(1)
    }

    /// Frees space for one new entry in `database_id`'s shard.
    ///
    /// Evicts one entry from the shard with the highest overshoot score, then
    /// keeps evicting while `database_id` still holds more than its weighted
    /// fair share. Stops as soon as nothing is left to evict.
    fn make_room_for(&mut self, database_id: u64) {
        while self.evict_from_highest_overshoot(database_id) {
            let total_weight = self.total_weight();
            let capacity = self.capacity as u64;
            let over_share = match self.shards.get(&database_id) {
                Some(shard) => {
                    let fair_share = capacity
                        .saturating_mul(u64::from(shard.weight))
                        .saturating_div(total_weight) as usize;
                    shard.entries.len() > fair_share
                }
                None => false,
            };
            if !over_share {
                break;
            }
        }
    }

    /// Evicts the oldest entry of the non-empty shard with the highest
    /// overshoot score.
    ///
    /// Ties prefer `hint_db_id`, then the lowest database id, so the victim
    /// does not depend on `HashMap` iteration order. Returns `false` when
    /// every shard is empty.
    fn evict_from_highest_overshoot(&mut self, hint_db_id: u64) -> bool {
        let total_weight = self.total_weight();
        let victim = self
            .shards
            .iter()
            .filter(|(_, shard)| !shard.entries.is_empty())
            .max_by_key(|&(&id, shard)| {
                (
                    shard.overshoot_score(total_weight),
                    id == hint_db_id,
                    std::cmp::Reverse(id),
                )
            })
            .map(|(&id, _)| id);
        let removed = match victim.and_then(|id| self.shards.get_mut(&id)) {
            Some(shard) => shard.evict_one(),
            None => false,
        };
        if removed {
            self.total = self.total.saturating_sub(1);
        }
        removed
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_put_get() {
        let mut cache = DocCache::new(16);
        cache.put(0, 1, "users", "u1", b"alice");
        let present = cache.get(0, 1, "users", "u1");
        assert_eq!(present, Some(b"alice".as_slice()));
        let absent = cache.get(0, 1, "users", "u2");
        assert_eq!(absent, None);
    }

    #[test]
    fn overwrite_updates_value() {
        let mut cache = DocCache::new(16);
        // The second write must replace the first.
        for value in [b"alice", b"ALICE"] {
            cache.put(0, 1, "users", "u1", value);
        }
        assert_eq!(cache.get(0, 1, "users", "u1"), Some(b"ALICE".as_slice()));
    }

    #[test]
    fn invalidate_removes_entry() {
        let mut cache = DocCache::new(16);
        cache.put(0, 1, "users", "u1", b"alice");
        cache.invalidate(0, 1, "users", "u1");
        assert!(cache.get(0, 1, "users", "u1").is_none());
    }

    #[test]
    fn tenant_isolation() {
        let mut cache = DocCache::new(16);
        cache.put(0, 1, "users", "u1", b"tenant1");
        cache.put(0, 2, "users", "u1", b"tenant2");
        let expected: [(u64, &[u8]); 2] = [(1, b"tenant1"), (2, b"tenant2")];
        for (tenant, value) in expected {
            assert_eq!(cache.get(0, tenant, "users", "u1"), Some(value));
        }
    }

    #[test]
    fn hit_rate_tracking() {
        let mut cache = DocCache::new(16);
        cache.put(0, 1, "c", "a", b"1");
        // Two hits followed by one miss.
        for doc in ["a", "a", "b"] {
            cache.get(0, 1, "c", doc);
        }
        assert!((cache.hit_rate() - 0.6667).abs() < 0.01);
        assert_eq!(cache.total_lookups(), 3);
    }

    #[test]
    fn cache_key_uniqueness_across_databases() {
        let mut cache = DocCache::new(16);
        cache.put(1, 5, "col", "doc", b"db1");
        cache.put(2, 5, "col", "doc", b"db2");
        let expected: [(u64, &[u8]); 2] = [(1, b"db1"), (2, b"db2")];
        for (database, value) in expected {
            assert_eq!(cache.get(database, 5, "col", "doc"), Some(value));
        }
    }

    #[test]
    fn hot_db_does_not_evict_cold_db_below_proportional_share() {
        let mut cache = DocCache::new(8);
        cache.set_database_weight(1, 1);
        cache.set_database_weight(2, 1);
        (0..4u32).for_each(|i| cache.put(1, 1, "c", &format!("db1-{i}"), b"v"));
        (0..4u32).for_each(|i| cache.put(2, 1, "c", &format!("db2-{i}"), b"v"));
        assert_eq!(cache.len(), 8);
        // Keep writing only to DB1; DB2 is at its fair share and must survive.
        (4..8u32).for_each(|i| cache.put(1, 1, "c", &format!("db1-{i}"), b"v"));
        let db2_resident: usize = (0..4u32)
            .filter(|i| cache.get(2, 1, "c", &format!("db2-{i}")).is_some())
            .count();
        assert_eq!(
            db2_resident, 4,
            "DB2 should retain all 4 entries; resident={db2_resident}"
        );
    }

    #[test]
    fn weight_ratio_affects_resident_sets() {
        let capacity = 10;
        let mut cache = DocCache::new(capacity);
        cache.set_database_weight(1, 1);
        cache.set_database_weight(2, 4);
        (0..5u32).for_each(|i| {
            cache.put(1, 1, "c", &format!("a{i}"), b"v");
            cache.put(2, 1, "c", &format!("b{i}"), b"v");
        });
        assert_eq!(cache.len(), capacity);
        (5..10u32).for_each(|i| cache.put(1, 1, "c", &format!("a{i}"), b"v"));
        let db1_count = (0..10u32)
            .filter(|i| cache.get(1, 1, "c", &format!("a{i}")).is_some())
            .count();
        let db2_count = (0..5u32)
            .filter(|i| cache.get(2, 1, "c", &format!("b{i}")).is_some())
            .count();
        assert!(
            db2_count > db1_count,
            "DB2 (weight=4) should have more resident entries than DB1 (weight=1); db2={db2_count} db1={db1_count}"
        );
    }

    #[test]
    fn evict_collection_removes_correct_entries() {
        let mut cache = DocCache::new(16);
        cache.put(1, 1, "col_a", "d1", b"1");
        cache.put(1, 1, "col_b", "d1", b"2");
        cache.put(2, 1, "col_a", "d1", b"3");
        cache.evict_collection(1, 1, "col_a");
        assert!(cache.get(1, 1, "col_a", "d1").is_none());
        assert_eq!(cache.get(1, 1, "col_b", "d1"), Some(b"2".as_slice()));
        assert_eq!(cache.get(2, 1, "col_a", "d1"), Some(b"3".as_slice()));
    }

    #[test]
    fn evict_tenant_removes_correct_entries() {
        let mut cache = DocCache::new(16);
        cache.put(1, 1, "col", "d1", b"t1");
        cache.put(1, 2, "col", "d1", b"t2");
        cache.put(2, 1, "col", "d1", b"db2");
        cache.evict_tenant(1, 1);
        assert!(cache.get(1, 1, "col", "d1").is_none());
        assert_eq!(cache.get(1, 2, "col", "d1"), Some(b"t2".as_slice()));
        assert_eq!(cache.get(2, 1, "col", "d1"), Some(b"db2".as_slice()));
    }
}