use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use nodedb_types::{DatabaseId, TenantId};
use crate::budget::Budget;
use crate::engine::EngineId;
use crate::error::{MemError, Result};
use crate::pressure::{PressureLevel, PressureThresholds};
use crate::reservation_token::ReservationToken;
/// Process-wide allocation counter paired with its fixed ceiling.
pub struct GlobalCounter {
    /// Bytes currently reserved across the whole governor.
    pub(crate) allocated: AtomicUsize,
    /// Upper bound that `allocated` is checked against when reserving.
    pub(crate) ceiling: usize,
}

/// Hand-written so the output shows the counter's current value rather than
/// the `AtomicUsize` wrapper.
impl std::fmt::Debug for GlobalCounter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Take one relaxed snapshot of the counter; this is diagnostic output only.
        let allocated_now = self.allocated.load(Ordering::Relaxed);
        let mut out = f.debug_struct("GlobalCounter");
        out.field("allocated", &allocated_now);
        out.field("ceiling", &self.ceiling);
        out.finish()
    }
}
/// A byte budget with its own counter, used for per-database and
/// per-(database, tenant) caps.
///
/// The counter sits behind an `Arc` so a successful reservation can hand a
/// shared handle to it back to the caller.
#[derive(Debug)]
struct ScopedBudget {
    /// Maximum number of bytes that may be reserved against this budget.
    limit: usize,
    /// Bytes currently reserved against this budget.
    allocated: Arc<AtomicUsize>,
}

impl ScopedBudget {
    /// Creates a budget capped at `limit` bytes with nothing reserved.
    fn new(limit: usize) -> Self {
        Self {
            limit,
            allocated: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Attempts to add `size` bytes to the counter without exceeding `limit`.
    ///
    /// Returns a handle to the shared counter on success, or `None` when the
    /// request does not fit. A request so large that `current + size` would
    /// overflow `usize` is treated as "does not fit" instead of panicking
    /// (debug builds) or wrapping around and slipping under the limit
    /// (release builds).
    fn try_reserve(&self, size: usize) -> Option<Arc<AtomicUsize>> {
        let mut current = self.allocated.load(Ordering::Relaxed);
        loop {
            // `checked_add` yields `None` on overflow; the filter enforces the cap.
            let next = current.checked_add(size).filter(|&n| n <= self.limit)?;
            match self.allocated.compare_exchange_weak(
                current,
                next,
                Ordering::AcqRel,
                Ordering::Relaxed,
            ) {
                Ok(_) => return Some(Arc::clone(&self.allocated)),
                // Lost a race (or spurious failure): retry from the value we observed.
                Err(observed) => current = observed,
            }
        }
    }

    /// Bytes still available under the limit; zero if the counter is at or above it.
    fn available(&self) -> usize {
        let alloc = self.allocated.load(Ordering::Relaxed);
        self.limit.saturating_sub(alloc)
    }
}
/// Static configuration used to construct a [`MemoryGovernor`].
#[derive(Debug, Clone)]
pub struct GovernorConfig {
    /// Ceiling, in bytes, for the governor's global counter.
    pub global_ceiling: usize,
    /// Per-engine byte limits; each entry becomes one engine budget.
    pub engine_limits: HashMap<EngineId, usize>,
}

impl GovernorConfig {
    /// Checks that the engine limits together fit under the global ceiling.
    ///
    /// # Errors
    ///
    /// Returns [`MemError::GlobalCeilingExceeded`] (with `requested: 0`) when
    /// the sum of all engine limits is greater than `global_ceiling`. A sum
    /// that overflows `usize` is also rejected and reported with
    /// `allocated: usize::MAX`, rather than panicking or wrapping around to a
    /// small value that would pass the check.
    pub fn validate(&self) -> Result<()> {
        // `None` means the running total overflowed `usize`.
        let total = self
            .engine_limits
            .values()
            .try_fold(0usize, |acc, &limit| acc.checked_add(limit));
        match total {
            Some(sum) if sum <= self.global_ceiling => Ok(()),
            over_or_overflow => Err(MemError::GlobalCeilingExceeded {
                allocated: over_or_overflow.unwrap_or(usize::MAX),
                ceiling: self.global_ceiling,
                requested: 0,
            }),
        }
    }
}
/// Layered memory accounting: one global counter, optional per-database and
/// per-(database, tenant) budgets, and one budget per engine.
#[derive(Debug)]
pub struct MemoryGovernor {
// One budget per engine listed in the config; fixed after construction.
budgets: HashMap<EngineId, Budget>,
// Shared global counter; a clone of this Arc is handed to every reservation token.
global_counter: Arc<GlobalCounter>,
// Copy of the configured global ceiling, used for utilization reporting.
global_ceiling: usize,
// Maps utilization percentages to pressure levels; starts at the default thresholds.
thresholds: PressureThresholds,
// Optional per-database caps; a database without an entry is not capped at this layer.
database_budgets: RwLock<HashMap<DatabaseId, ScopedBudget>>,
// Optional per-(database, tenant) caps; a pair without an entry is not capped at this layer.
tenant_budgets: RwLock<HashMap<(DatabaseId, TenantId), ScopedBudget>>,
}
impl MemoryGovernor {
pub fn new(config: GovernorConfig) -> Result<Self> {
config.validate()?;
let mut budgets = HashMap::new();
for (engine, limit) in &config.engine_limits {
budgets.insert(*engine, Budget::new(*limit));
}
let global_counter = Arc::new(GlobalCounter {
allocated: AtomicUsize::new(0),
ceiling: config.global_ceiling,
});
Ok(Self {
budgets,
global_counter,
global_ceiling: config.global_ceiling,
thresholds: PressureThresholds::default(),
database_budgets: RwLock::new(HashMap::new()),
tenant_budgets: RwLock::new(HashMap::new()),
})
}
pub fn set_database_budget(&self, db: DatabaseId, max_bytes: usize) {
let mut map = self
.database_budgets
.write()
.unwrap_or_else(|p| p.into_inner());
map.insert(db, ScopedBudget::new(max_bytes));
}
pub fn clear_database_budget(&self, db: DatabaseId) {
let mut map = self
.database_budgets
.write()
.unwrap_or_else(|p| p.into_inner());
map.remove(&db);
}
pub fn set_tenant_budget(&self, db: DatabaseId, tenant: TenantId, max_bytes: usize) {
let mut map = self
.tenant_budgets
.write()
.unwrap_or_else(|p| p.into_inner());
map.insert((db, tenant), ScopedBudget::new(max_bytes));
}
pub fn clear_tenant_budget(&self, db: DatabaseId, tenant: TenantId) {
let mut map = self
.tenant_budgets
.write()
.unwrap_or_else(|p| p.into_inner());
map.remove(&(db, tenant));
}
pub fn try_reserve(
&self,
db: DatabaseId,
tenant: TenantId,
engine: EngineId,
size: usize,
) -> Result<ReservationToken> {
let global_arc = Arc::clone(&self.global_counter);
if size > 0 {
loop {
let current = global_arc.allocated.load(Ordering::Relaxed);
if current + size > global_arc.ceiling {
return Err(MemError::GlobalCeilingExceeded {
allocated: current,
ceiling: global_arc.ceiling,
requested: size,
});
}
match global_arc.allocated.compare_exchange_weak(
current,
current + size,
Ordering::AcqRel,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(_) => continue,
}
}
}
let db_counter = {
let map = self
.database_budgets
.read()
.unwrap_or_else(|p| p.into_inner());
if let Some(budget) = map.get(&db) {
match budget.try_reserve(size) {
Some(arc) => Some(arc),
None => {
if size > 0 {
global_arc.allocated.fetch_sub(size, Ordering::Relaxed);
}
return Err(MemError::DatabaseBudgetExhausted {
db,
requested: size,
available: budget.available(),
limit: budget.limit,
});
}
}
} else {
None
}
};
let tenant_counter = {
let map = self
.tenant_budgets
.read()
.unwrap_or_else(|p| p.into_inner());
if let Some(budget) = map.get(&(db, tenant)) {
match budget.try_reserve(size) {
Some(arc) => Some(arc),
None => {
if let Some(ref ctr) = db_counter
&& size > 0
{
ctr.fetch_sub(size, Ordering::Relaxed);
}
if size > 0 {
global_arc.allocated.fetch_sub(size, Ordering::Relaxed);
}
return Err(MemError::TenantBudgetExhausted {
db,
tenant,
requested: size,
available: budget.available(),
limit: budget.limit,
});
}
}
} else {
None
}
};
let engine_budget = self
.budgets
.get(&engine)
.ok_or(MemError::UnknownEngine(engine))?;
let engine_counter = if let Some(arc) = engine_budget.try_reserve_arc(size) {
Some(arc)
} else {
if let Some(ref ctr) = tenant_counter
&& size > 0
{
ctr.fetch_sub(size, Ordering::Relaxed);
}
if let Some(ref ctr) = db_counter
&& size > 0
{
ctr.fetch_sub(size, Ordering::Relaxed);
}
if size > 0 {
global_arc.allocated.fetch_sub(size, Ordering::Relaxed);
}
return Err(MemError::BudgetExhausted {
engine,
requested: size,
available: engine_budget.available(),
limit: engine_budget.limit(),
});
};
Ok(ReservationToken::new(
crate::reservation_token::ReservationParams {
global_counter: global_arc,
database_counter: db_counter,
tenant_counter,
engine_counter,
size,
db,
tenant,
engine,
},
))
}
pub fn release(&self, engine: EngineId, size: usize) {
if let Some(budget) = self.budgets.get(&engine) {
budget.release(size);
}
crate::budget::atomic_saturating_sub(&self.global_counter.allocated, size);
}
pub fn budget(&self, engine: EngineId) -> Option<&Budget> {
self.budgets.get(&engine)
}
pub fn global_ceiling(&self) -> usize {
self.global_ceiling
}
pub fn total_allocated(&self) -> usize {
self.budgets.values().map(|b| b.allocated()).sum()
}
pub fn total_over_release_count(&self) -> usize {
self.budgets.values().map(|b| b.over_release_count()).sum()
}
pub fn global_utilization_percent(&self) -> u8 {
if self.global_ceiling == 0 {
return 100;
}
((self.total_allocated() as u128 * 100) / self.global_ceiling as u128).min(100) as u8
}
pub fn engine_pressure(&self, engine: EngineId) -> PressureLevel {
self.budgets
.get(&engine)
.map(|b| self.thresholds.level_for(b.utilization_percent()))
.unwrap_or(PressureLevel::Emergency)
}
pub fn global_pressure(&self) -> PressureLevel {
self.thresholds.level_for(self.global_utilization_percent())
}
pub fn worst_engine_pressure(&self) -> PressureLevel {
self.budgets
.values()
.map(|b| self.thresholds.level_for(b.utilization_percent()))
.max()
.unwrap_or(PressureLevel::Normal)
}
pub fn set_thresholds(&mut self, thresholds: PressureThresholds) {
self.thresholds = thresholds;
}
pub fn snapshot(&self) -> Vec<EngineSnapshot> {
self.budgets
.iter()
.map(|(engine, budget)| EngineSnapshot {
engine: *engine,
allocated: budget.allocated(),
limit: budget.limit(),
peak: budget.peak(),
rejections: budget.rejections(),
utilization_percent: budget.utilization_percent(),
})
.collect()
}
}
/// Point-in-time copy of one engine budget's statistics, as produced by
/// `MemoryGovernor::snapshot`.
#[derive(Debug, Clone)]
pub struct EngineSnapshot {
/// Engine this snapshot describes.
pub engine: EngineId,
/// Value of the budget's `allocated()` when the snapshot was taken.
pub allocated: usize,
/// Value of the budget's `limit()` when the snapshot was taken.
pub limit: usize,
/// Value of the budget's `peak()` when the snapshot was taken.
pub peak: usize,
/// Value of the budget's `rejections()` when the snapshot was taken.
pub rejections: usize,
/// Value of the budget's `utilization_percent()` when the snapshot was taken.
pub utilization_percent: u8,
}
// Unit tests for the layered reservation path and the reporting helpers.
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use std::thread;
use nodedb_types::{DatabaseId, TenantId};
use super::*;
// Shared fixture: three engines whose limits (4096 + 2048 + 1024 = 7168)
// fit under the 8192-byte global ceiling.
fn test_config() -> GovernorConfig {
let mut engine_limits = HashMap::new();
engine_limits.insert(EngineId::Vector, 4096);
engine_limits.insert(EngineId::Query, 2048);
engine_limits.insert(EngineId::Timeseries, 1024);
GovernorConfig {
global_ceiling: 8192,
engine_limits,
}
}
// Database id used by most tests.
fn db() -> DatabaseId {
DatabaseId::DEFAULT
}
// Tenant id used by most tests.
fn tenant() -> TenantId {
TenantId::new(1)
}
// A reservation that fits is charged to the engine budget and reported by the token.
#[test]
fn reserve_within_budget() {
let gov = MemoryGovernor::new(test_config()).unwrap();
let tok = gov
.try_reserve(db(), tenant(), EngineId::Vector, 1000)
.unwrap();
assert_eq!(gov.budget(EngineId::Vector).unwrap().allocated(), 1000);
assert_eq!(tok.size(), 1000);
}
// 3000 bytes exceeds the Query engine's 2048-byte limit but not the global ceiling.
#[test]
fn reserve_exceeds_engine_budget() {
let gov = MemoryGovernor::new(test_config()).unwrap();
let err = gov
.try_reserve(db(), tenant(), EngineId::Query, 3000)
.unwrap_err();
assert!(matches!(err, MemError::BudgetExhausted { .. }));
}
// Fill every engine to its limit (tokens are kept alive), then request more.
// Either the engine layer or the global layer may be the one that refuses.
#[test]
fn reserve_exceeds_global_ceiling() {
let gov = MemoryGovernor::new(test_config()).unwrap();
let _t1 = gov
.try_reserve(db(), tenant(), EngineId::Vector, 4096)
.unwrap();
let _t2 = gov
.try_reserve(db(), tenant(), EngineId::Query, 2048)
.unwrap();
let _t3 = gov
.try_reserve(db(), tenant(), EngineId::Timeseries, 1024)
.unwrap();
let err = gov
.try_reserve(db(), tenant(), EngineId::Timeseries, 2000)
.unwrap_err();
assert!(matches!(
err,
MemError::BudgetExhausted { .. } | MemError::GlobalCeilingExceeded { .. }
));
}
// Dropping the token at the end of the inner scope must return the bytes
// to the engine budget.
#[test]
fn raii_release_returns_to_baseline() {
let gov = MemoryGovernor::new(test_config()).unwrap();
{
let tok = gov
.try_reserve(db(), tenant(), EngineId::Vector, 1000)
.unwrap();
assert_eq!(gov.budget(EngineId::Vector).unwrap().allocated(), 1000);
assert_eq!(tok.size(), 1000);
}
assert_eq!(
gov.budget(EngineId::Vector).unwrap().allocated(),
0,
"engine counter must be returned on drop"
);
}
// The database cap (500) is checked before the tenant cap (4096), so a
// 600-byte request fails at the database layer.
#[test]
fn database_cap_denies_even_with_tenant_headroom() {
let gov = MemoryGovernor::new(test_config()).unwrap();
gov.set_database_budget(db(), 500);
gov.set_tenant_budget(db(), tenant(), 4096);
let err = gov
.try_reserve(db(), tenant(), EngineId::Vector, 600)
.unwrap_err();
assert!(
matches!(err, MemError::DatabaseBudgetExhausted { .. }),
"expected DatabaseBudgetExhausted, got {err:?}"
);
}
// The global ceiling (200) is checked first, so generous database and
// tenant caps do not help a 300-byte request.
#[test]
fn global_cap_denies_even_with_database_and_tenant_headroom() {
let mut engine_limits = HashMap::new();
engine_limits.insert(EngineId::Vector, 200);
let gov = MemoryGovernor::new(GovernorConfig {
global_ceiling: 200,
engine_limits,
})
.unwrap();
gov.set_database_budget(db(), 1024);
gov.set_tenant_budget(db(), tenant(), 1024);
let err = gov
.try_reserve(db(), tenant(), EngineId::Vector, 300)
.unwrap_err();
assert!(
matches!(err, MemError::GlobalCeilingExceeded { .. }),
"expected GlobalCeilingExceeded, got {err:?}"
);
}
// The request passes the database cap (4096) but not the tenant cap (300).
#[test]
fn tenant_cap_denies_with_db_headroom() {
let gov = MemoryGovernor::new(test_config()).unwrap();
gov.set_database_budget(db(), 4096);
gov.set_tenant_budget(db(), tenant(), 300);
let err = gov
.try_reserve(db(), tenant(), EngineId::Vector, 400)
.unwrap_err();
assert!(
matches!(err, MemError::TenantBudgetExhausted { .. }),
"expected TenantBudgetExhausted, got {err:?}"
);
}
// A failure at the database layer must undo the global increment taken before it.
#[test]
fn partial_increments_rolled_back_on_db_failure() {
let gov = MemoryGovernor::new(test_config()).unwrap();
gov.set_database_budget(db(), 50);
let _ = gov
.try_reserve(db(), tenant(), EngineId::Vector, 100)
.unwrap_err();
assert_eq!(
gov.global_counter.allocated.load(Ordering::Relaxed),
0,
"global counter must be rolled back on database-layer failure"
);
}
// A failure at the tenant layer must undo both the global and the database increments.
#[test]
fn partial_increments_rolled_back_on_tenant_failure() {
let gov = MemoryGovernor::new(test_config()).unwrap();
gov.set_database_budget(db(), 4096);
gov.set_tenant_budget(db(), tenant(), 50);
let _ = gov
.try_reserve(db(), tenant(), EngineId::Vector, 100)
.unwrap_err();
assert_eq!(
gov.global_counter.allocated.load(Ordering::Relaxed),
0,
"global counter must be rolled back on tenant-layer failure"
);
let db_map = gov.database_budgets.read().unwrap();
let db_alloc = db_map[&db()].allocated.load(Ordering::Relaxed);
assert_eq!(db_alloc, 0, "database counter must be rolled back");
}
// Eight threads each reserve 1000 bytes against 10_000-byte caps at the
// engine, database and global layers; totals must stay within the caps.
// NOTE(review): with 8 threads the `<= 10` bound on successes can never be
// violated; consider spawning more than 10 threads to exercise the cap.
#[test]
fn concurrent_reserves_never_exceed_cap() {
let mut limits = HashMap::new();
limits.insert(EngineId::Vector, 10_000);
let gov = Arc::new(
MemoryGovernor::new(GovernorConfig {
global_ceiling: 10_000,
engine_limits: limits,
})
.unwrap(),
);
gov.set_database_budget(DatabaseId::DEFAULT, 10_000);
let n_threads = 8;
let reserve_size = 1_000;
let mut handles = Vec::new();
for i in 0..n_threads {
let gov_clone = Arc::clone(&gov);
handles.push(thread::spawn(move || {
gov_clone.try_reserve(
DatabaseId::DEFAULT,
TenantId::new(i as u64),
EngineId::Vector,
reserve_size,
)
}));
}
let results: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
// Successful tokens are kept in `successful` so they are still alive
// while the counters are inspected below.
let successful: Vec<_> = results.into_iter().filter_map(|r| r.ok()).collect();
assert!(
successful.len() <= 10,
"expected at most 10 successful reservations, got {}",
successful.len()
);
let engine_alloc = gov.budget(EngineId::Vector).unwrap().allocated();
assert!(
engine_alloc <= 10_000,
"engine total {engine_alloc} must not exceed cap 10000"
);
let global_alloc = gov.global_counter.allocated.load(Ordering::Relaxed);
assert!(
global_alloc <= 10_000,
"global total {global_alloc} must not exceed ceiling 10000"
);
}
// `EngineId::Crdt` is not part of `test_config`, so the reservation is refused.
#[test]
fn unknown_engine_rejected() {
let gov = MemoryGovernor::new(test_config()).unwrap();
let err = gov
.try_reserve(db(), tenant(), EngineId::Crdt, 100)
.unwrap_err();
assert!(matches!(err, MemError::UnknownEngine(EngineId::Crdt)));
}
// The snapshot has one entry per configured engine; 2048 of 4096 bytes is 50%.
#[test]
fn snapshot_reports_all_engines() {
let gov = MemoryGovernor::new(test_config()).unwrap();
let _tok = gov
.try_reserve(db(), tenant(), EngineId::Vector, 2048)
.unwrap();
let snap = gov.snapshot();
assert_eq!(snap.len(), 3);
let vector_snap = snap.iter().find(|s| s.engine == EngineId::Vector).unwrap();
assert_eq!(vector_snap.allocated, 2048);
assert_eq!(vector_snap.limit, 4096);
assert_eq!(vector_snap.utilization_percent, 50);
}
// 2868 of 4096 bytes (about 70%) is expected to map to Warning under the
// default thresholds.
#[test]
fn engine_pressure_levels() {
let gov = MemoryGovernor::new(test_config()).unwrap();
assert_eq!(gov.engine_pressure(EngineId::Vector), PressureLevel::Normal);
let _tok1 = gov
.try_reserve(db(), tenant(), EngineId::Vector, 2868)
.unwrap();
assert_eq!(
gov.engine_pressure(EngineId::Vector),
PressureLevel::Warning
);
}
// Only the Query engine is loaded (1800 of 2048 bytes, about 87%); the
// worst-engine level must reflect it while Vector stays Normal.
#[test]
fn worst_engine_pressure_picks_highest() {
let gov = MemoryGovernor::new(test_config()).unwrap();
assert_eq!(gov.worst_engine_pressure(), PressureLevel::Normal);
let _tok = gov
.try_reserve(db(), tenant(), EngineId::Query, 1800)
.unwrap();
assert_eq!(gov.engine_pressure(EngineId::Vector), PressureLevel::Normal);
assert_eq!(gov.worst_engine_pressure(), PressureLevel::Critical);
}
// Engine limits summing to 7168 cannot fit under a 100-byte global ceiling.
#[test]
fn invalid_config_rejected() {
let mut config = test_config();
config.global_ceiling = 100;
assert!(MemoryGovernor::new(config).is_err());
}
}