use dashmap::DashMap;
use parking_lot::Mutex;
use std::{
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::{Duration, Instant},
};
/// Load timeout used by [`TenantLoader::new`]: thirty seconds.
pub const DEFAULT_LOAD_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Per-tenant load bookkeeping.
///
/// Tracks which tenants are loaded, one mutex per tenant (for serializing
/// loads), bytes accounted to each tenant, and a last-used timestamp per
/// tenant for LRU victim selection. All maps are `DashMap`s, so every method
/// takes `&self` and the loader can be shared (e.g. behind an `Arc`).
#[derive(Debug)]
pub struct TenantLoader {
    /// Tenant ids currently marked loaded; the `()` value is unused.
    loaded: DashMap<String, ()>,
    /// One mutex per tenant id, created lazily by `lock_for`. No code path
    /// removes entries, so a tenant id always maps to the same mutex.
    locks: DashMap<String, Arc<Mutex<()>>>,
    /// Bytes accounted to each tenant via `add_bytes`.
    bytes: DashMap<String, u64>,
    /// Last time each tenant was loaded or touched; read by LRU selection.
    last_used: DashMap<String, Instant>,
    /// Total byte budget across all tenants; `0` means "no budget".
    byte_budget: AtomicU64,
    /// Value reported by `load_timeout`; presumably enforced by callers while
    /// loading a tenant — TODO confirm.
    load_timeout: Duration,
}
impl TenantLoader {
    /// Creates a loader using [`DEFAULT_LOAD_TIMEOUT`] with the byte budget
    /// disabled.
    pub fn new() -> Self {
        Self::with_timeout(DEFAULT_LOAD_TIMEOUT)
    }

    /// Creates a loader with a caller-chosen load timeout and the byte budget
    /// disabled (budget `0`, see [`Self::over_budget`]).
    pub fn with_timeout(load_timeout: Duration) -> Self {
        Self {
            loaded: DashMap::new(),
            locks: DashMap::new(),
            bytes: DashMap::new(),
            last_used: DashMap::new(),
            byte_budget: AtomicU64::new(0),
            load_timeout,
        }
    }

    /// Sets the total byte budget across all tenants; `0` disables it.
    pub fn set_byte_budget(&self, budget: u64) {
        // Relaxed is sufficient: the budget is a standalone setting and no
        // other data is published together with it.
        self.byte_budget.store(budget, Ordering::Relaxed);
    }

    /// Returns the current byte budget (`0` means disabled).
    pub fn byte_budget(&self) -> u64 {
        self.byte_budget.load(Ordering::Relaxed)
    }

    /// Returns `true` when a non-zero budget is set and
    /// [`Self::total_bytes`] exceeds it.
    pub fn over_budget(&self) -> bool {
        let budget = self.byte_budget();
        budget != 0 && self.total_bytes() > budget
    }

    /// Stamps `tenant_id` as used "now" for LRU purposes.
    ///
    /// NOTE(review): this inserts a stamp even for tenants that are not
    /// loaded; such entries are only cleared by [`Self::mark_unloaded`].
    pub fn touch(&self, tenant_id: &str) {
        self.last_used.insert(tenant_id.to_string(), Instant::now());
    }

    /// Returns the loaded tenant with the oldest `last_used` stamp, ignoring
    /// `excluded`.
    ///
    /// Returns `None` if no other loaded tenant exists. A loaded tenant
    /// without a stamp is logged and skipped. On equal stamps the first
    /// candidate seen wins; `DashMap` iteration order is unspecified.
    pub fn pick_lru_excluding(&self, excluded: &str) -> Option<String> {
        let mut victim: Option<(String, Instant)> = None;
        // Iterating holds a read lock on one `loaded` shard at a time while we
        // briefly read `last_used`; no method here nests these maps' locks in
        // the opposite order, so this cannot deadlock.
        for kv in &self.loaded {
            let tenant = kv.key();
            if tenant == excluded {
                continue;
            }
            let Some(last) = self.last_used.get(tenant).map(|v| *v) else {
                tracing::warn!(
                    tenant_id = %tenant,
                    "loaded tenant missing last_used stamp — skipping in LRU pick"
                );
                continue;
            };
            let is_older = match &victim {
                None => true,
                Some((_, best)) => last < *best,
            };
            if is_older {
                victim = Some((tenant.clone(), last));
            }
        }
        victim.map(|(tenant, _)| tenant)
    }

    /// Returns `true` if `tenant_id` is currently marked loaded.
    pub fn is_loaded(&self, tenant_id: &str) -> bool {
        self.loaded.contains_key(tenant_id)
    }

    /// Marks `tenant_id` as loaded and stamps it as used now.
    pub fn mark_loaded(&self, tenant_id: &str) {
        // Stamp before publishing into `loaded`: a concurrent
        // `pick_lru_excluding` must never observe a loaded tenant without a
        // stamp, which would trigger a spurious "missing last_used" warning.
        self.last_used.insert(tenant_id.to_string(), Instant::now());
        self.loaded.insert(tenant_id.to_string(), ());
    }

    /// Clears the loaded flag, byte accounting, and LRU stamp for `tenant_id`.
    ///
    /// The tenant's mutex in `locks` is deliberately kept: callers may still
    /// hold or be waiting on it, and replacing it would let two loads of the
    /// same tenant run concurrently.
    pub fn mark_unloaded(&self, tenant_id: &str) {
        self.loaded.remove(tenant_id);
        self.bytes.remove(tenant_id);
        self.last_used.remove(tenant_id);
    }

    /// Adds `n` bytes to the tenant's tally, saturating at `u64::MAX` instead
    /// of overflowing.
    pub fn add_bytes(&self, tenant_id: &str, n: u64) {
        let mut tally = self.bytes.entry(tenant_id.to_string()).or_insert(0);
        let updated = tally.saturating_add(n);
        *tally = updated;
    }

    /// Returns the bytes accounted to `tenant_id`, or `0` if none.
    pub fn bytes_for(&self, tenant_id: &str) -> u64 {
        self.bytes.get(tenant_id).map_or(0, |v| *v)
    }

    /// Returns the sum of all tenants' bytes, saturating at `u64::MAX`.
    ///
    /// The sum is taken shard by shard, so it is not an atomic snapshot while
    /// other threads are adding or removing bytes.
    pub fn total_bytes(&self) -> u64 {
        self.bytes
            .iter()
            .fold(0u64, |acc, kv| acc.saturating_add(*kv.value()))
    }

    /// Returns a `(tenant_id, bytes)` pair for every tenant with a tally, in
    /// unspecified order.
    pub fn bytes_per_tenant(&self) -> Vec<(String, u64)> {
        self.bytes
            .iter()
            .map(|kv| (kv.key().clone(), *kv.value()))
            .collect()
    }

    /// Returns the per-tenant mutex, creating it on first use.
    ///
    /// Repeated calls for the same tenant return the same `Arc` (entries are
    /// never removed), so holding the guard serializes work for that tenant.
    pub fn lock_for(&self, tenant_id: &str) -> Arc<Mutex<()>> {
        self.locks
            .entry(tenant_id.to_string())
            .or_insert_with(|| Arc::new(Mutex::new(())))
            .clone()
    }

    /// Returns the configured load timeout.
    pub fn load_timeout(&self) -> Duration {
        self.load_timeout
    }

    /// Number of tenants currently marked loaded (test-only helper).
    #[cfg(test)]
    pub fn loaded_count(&self) -> usize {
        self.loaded.len()
    }
}
impl Default for TenantLoader {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        thread,
    };

    #[test]
    fn test_is_loaded_false_until_marked() {
        let loader = TenantLoader::new();
        assert!(!loader.is_loaded("alice"));
        loader.mark_loaded("alice");
        assert!(loader.is_loaded("alice"));
        assert!(!loader.is_loaded("bob"));
    }

    #[test]
    fn test_mark_loaded_is_idempotent() {
        let loader = TenantLoader::new();
        loader.mark_loaded("alice");
        loader.mark_loaded("alice");
        assert_eq!(loader.loaded_count(), 1);
    }

    #[test]
    fn test_load_timeout_defaults_and_overrides() {
        assert_eq!(TenantLoader::new().load_timeout(), DEFAULT_LOAD_TIMEOUT);
        assert_eq!(TenantLoader::default().load_timeout(), DEFAULT_LOAD_TIMEOUT);
        let custom = Duration::from_millis(250);
        assert_eq!(TenantLoader::with_timeout(custom).load_timeout(), custom);
    }

    #[test]
    fn test_lock_for_returns_same_mutex_per_tenant() {
        let loader = TenantLoader::new();
        let lock_a1 = loader.lock_for("alice");
        let lock_a2 = loader.lock_for("alice");
        let lock_b = loader.lock_for("bob");
        assert!(Arc::ptr_eq(&lock_a1, &lock_a2));
        assert!(!Arc::ptr_eq(&lock_a1, &lock_b));
    }

    /// Unloading must not swap out the tenant's mutex, otherwise a waiter on
    /// the old mutex and a new caller could load the same tenant concurrently.
    #[test]
    fn test_lock_survives_mark_unloaded() {
        let loader = TenantLoader::new();
        let before = loader.lock_for("alice");
        loader.mark_loaded("alice");
        loader.mark_unloaded("alice");
        let after = loader.lock_for("alice");
        assert!(Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn test_singleflight_blocks_second_caller_until_first_releases() {
        let loader = Arc::new(TenantLoader::new());
        let progress = Arc::new(AtomicUsize::new(0));

        let l1 = Arc::clone(&loader);
        let p1 = Arc::clone(&progress);
        let t1 = thread::spawn(move || {
            let lock = l1.lock_for("alice");
            let _g = lock.lock();
            p1.store(1, Ordering::SeqCst);
            thread::sleep(Duration::from_millis(80));
            p1.store(2, Ordering::SeqCst);
            l1.mark_loaded("alice");
        });

        // Wait until the first caller holds the lock before starting the second.
        while progress.load(Ordering::SeqCst) < 1 {
            thread::sleep(Duration::from_millis(5));
        }

        let l2 = Arc::clone(&loader);
        let p2 = Arc::clone(&progress);
        let t2 = thread::spawn(move || {
            let lock = l2.lock_for("alice");
            let _g = lock.lock();
            assert_eq!(
                p2.load(Ordering::SeqCst),
                2,
                "second caller acquired lock before first finished — singleflight broken"
            );
            assert!(l2.is_loaded("alice"));
        });

        t1.join().unwrap();
        t2.join().unwrap();
    }

    #[test]
    fn test_mark_unloaded_clears_loaded_and_bytes() {
        let loader = TenantLoader::new();
        loader.mark_loaded("alice");
        loader.add_bytes("alice", 100);
        loader.mark_loaded("bob");
        loader.add_bytes("bob", 200);
        loader.mark_unloaded("alice");
        assert!(!loader.is_loaded("alice"));
        assert_eq!(loader.bytes_for("alice"), 0);
        assert!(loader.is_loaded("bob"));
        assert_eq!(loader.bytes_for("bob"), 200);
        assert_eq!(loader.total_bytes(), 200);
    }

    #[test]
    fn test_bytes_default_to_zero() {
        let loader = TenantLoader::new();
        assert_eq!(loader.bytes_for("alice"), 0);
        assert_eq!(loader.total_bytes(), 0);
        assert!(loader.bytes_per_tenant().is_empty());
    }

    #[test]
    fn test_add_bytes_accumulates_per_tenant() {
        let loader = TenantLoader::new();
        loader.add_bytes("alice", 100);
        loader.add_bytes("alice", 50);
        loader.add_bytes("bob", 200);
        assert_eq!(loader.bytes_for("alice"), 150);
        assert_eq!(loader.bytes_for("bob"), 200);
        assert_eq!(loader.total_bytes(), 350);
        let mut snapshot = loader.bytes_per_tenant();
        snapshot.sort();
        assert_eq!(
            snapshot,
            vec![("alice".to_string(), 150), ("bob".to_string(), 200)]
        );
    }

    #[test]
    fn test_pick_lru_returns_oldest_excluding_target() {
        let loader = TenantLoader::new();
        loader.mark_loaded("alice");
        thread::sleep(Duration::from_millis(5));
        loader.mark_loaded("bob");
        thread::sleep(Duration::from_millis(5));
        loader.mark_loaded("carol");
        assert_eq!(
            loader.pick_lru_excluding("carol"),
            Some("alice".to_string())
        );
        loader.touch("alice");
        assert_eq!(loader.pick_lru_excluding("carol"), Some("bob".to_string()));
    }

    #[test]
    fn test_pick_lru_skips_unloaded_tenants() {
        let loader = TenantLoader::new();
        loader.mark_loaded("alice");
        thread::sleep(Duration::from_millis(5));
        loader.mark_loaded("bob");
        loader.mark_unloaded("alice");
        assert_eq!(loader.pick_lru_excluding("carol"), Some("bob".to_string()));
    }

    #[test]
    fn test_pick_lru_returns_none_when_only_excluded_is_loaded() {
        let loader = TenantLoader::new();
        loader.mark_loaded("alice");
        assert_eq!(loader.pick_lru_excluding("alice"), None);
    }

    #[test]
    fn test_pick_lru_returns_none_when_nothing_loaded() {
        let loader = TenantLoader::new();
        assert_eq!(loader.pick_lru_excluding("anyone"), None);
    }

    #[test]
    fn test_over_budget_zero_budget_means_disabled() {
        let loader = TenantLoader::new();
        loader.mark_loaded("alice");
        loader.add_bytes("alice", 1_000_000_000);
        assert!(!loader.over_budget());
        loader.set_byte_budget(100);
        assert!(loader.over_budget());
        loader.set_byte_budget(2_000_000_000);
        assert!(!loader.over_budget());
    }

    #[test]
    fn test_lock_for_distinct_tenants_does_not_serialize() {
        let loader = TenantLoader::new();
        let alice_lock = loader.lock_for("alice");
        let _alice_held = alice_lock.lock();
        // Sanity check: the mutex really is held, so the next assert is meaningful.
        assert!(alice_lock.try_lock().is_none());
        let bob_lock = loader.lock_for("bob");
        assert!(
            bob_lock.try_lock().is_some(),
            "bob's lock should not be blocked by alice's"
        );
    }
}