1use std::collections::HashMap;
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::sync::{Arc, RwLock};
26
27use nodedb_types::{DatabaseId, TenantId};
28
29use crate::budget::Budget;
30use crate::engine::EngineId;
31use crate::error::{MemError, Result};
32use crate::pressure::{PressureLevel, PressureThresholds};
33use crate::reservation_token::ReservationToken;
34
/// Shared allocation counter together with the ceiling it is checked against.
///
/// `allocated` is updated atomically by the reservation paths; `ceiling` is a
/// plain value fixed when the counter is built.
pub struct GlobalCounter {
    /// Amount currently counted as reserved.
    pub(crate) allocated: AtomicUsize,
    /// Upper bound that `allocated` is compared against when reserving.
    pub(crate) ceiling: usize,
}

/// Hand-written so the atomic is rendered as its current numeric value.
impl std::fmt::Debug for GlobalCounter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Snapshot the atomic once; a relaxed load is enough for diagnostics.
        let in_use = self.allocated.load(Ordering::Relaxed);

        let mut out = f.debug_struct("GlobalCounter");
        out.field("allocated", &in_use);
        out.field("ceiling", &self.ceiling);
        out.finish()
    }
}
52
/// A usage cap for one scope: a fixed limit plus a shared atomic counter of
/// what is currently reserved against it.
#[derive(Debug)]
struct ScopedBudget {
    /// Maximum total that may be reserved at any one time.
    limit: usize,
    /// Current reserved total. Held in an `Arc` so a successful reservation
    /// can hand the same counter to its caller.
    allocated: Arc<AtomicUsize>,
}

impl ScopedBudget {
    /// Creates a budget capped at `limit` with nothing reserved.
    fn new(limit: usize) -> Self {
        Self {
            limit,
            allocated: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Atomically adds `size` to the counter if the result stays within `limit`.
    ///
    /// Returns a handle to the shared counter on success and `None` when the
    /// request does not fit. The counter is left untouched on failure.
    ///
    /// A request so large that `current + size` would overflow `usize` is
    /// treated as "does not fit". A plain `+` here would panic in debug builds
    /// and, in release builds, wrap around and slip past the limit check.
    fn try_reserve(&self, size: usize) -> Option<Arc<AtomicUsize>> {
        self.allocated
            .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |current| {
                // `None` aborts the update: either overflow or over the limit.
                current.checked_add(size).filter(|&next| next <= self.limit)
            })
            .ok()
            .map(|_| Arc::clone(&self.allocated))
    }

    /// Remaining headroom under the limit, clamped at zero.
    fn available(&self) -> usize {
        let in_use = self.allocated.load(Ordering::Relaxed);
        self.limit.saturating_sub(in_use)
    }
}
95
96#[derive(Debug, Clone)]
98pub struct GovernorConfig {
99 pub global_ceiling: usize,
102
103 pub engine_limits: HashMap<EngineId, usize>,
105}
106
107impl GovernorConfig {
108 pub fn validate(&self) -> Result<()> {
110 let total: usize = self.engine_limits.values().sum();
111 if total > self.global_ceiling {
112 return Err(MemError::GlobalCeilingExceeded {
113 allocated: total,
114 ceiling: self.global_ceiling,
115 requested: 0,
116 });
117 }
118 Ok(())
119 }
120}
121
/// Tracks reservations against per-engine budgets under one global ceiling,
/// with optional per-database and per-(database, tenant) caps.
#[derive(Debug)]
pub struct MemoryGovernor {
    /// One `Budget` per configured engine, keyed by engine id.
    budgets: HashMap<EngineId, Budget>,

    /// Shared global counter; `try_reserve` clones this `Arc` into the
    /// reservation it returns.
    global_counter: Arc<GlobalCounter>,

    /// Copy of the configured global ceiling, returned by `global_ceiling()`
    /// and used as the denominator of `global_utilization_percent()`.
    global_ceiling: usize,

    /// Thresholds used to turn a utilization percentage into a `PressureLevel`.
    thresholds: PressureThresholds,

    /// Optional per-database caps. A database without an entry is not
    /// limited at this layer.
    database_budgets: RwLock<HashMap<DatabaseId, ScopedBudget>>,

    /// Optional caps keyed by `(database, tenant)`. A pair without an entry
    /// is not limited at this layer.
    tenant_budgets: RwLock<HashMap<(DatabaseId, TenantId), ScopedBudget>>,
}
149
150impl MemoryGovernor {
151 pub fn new(config: GovernorConfig) -> Result<Self> {
153 config.validate()?;
154
155 let mut budgets = HashMap::new();
156 for (engine, limit) in &config.engine_limits {
157 budgets.insert(*engine, Budget::new(*limit));
158 }
159
160 let global_counter = Arc::new(GlobalCounter {
161 allocated: AtomicUsize::new(0),
162 ceiling: config.global_ceiling,
163 });
164
165 Ok(Self {
166 budgets,
167 global_counter,
168 global_ceiling: config.global_ceiling,
169 thresholds: PressureThresholds::default(),
170 database_budgets: RwLock::new(HashMap::new()),
171 tenant_budgets: RwLock::new(HashMap::new()),
172 })
173 }
174
175 pub fn set_database_budget(&self, db: DatabaseId, max_bytes: usize) {
183 let mut map = self
184 .database_budgets
185 .write()
186 .unwrap_or_else(|p| p.into_inner());
187 map.insert(db, ScopedBudget::new(max_bytes));
188 }
189
190 pub fn clear_database_budget(&self, db: DatabaseId) {
192 let mut map = self
193 .database_budgets
194 .write()
195 .unwrap_or_else(|p| p.into_inner());
196 map.remove(&db);
197 }
198
199 pub fn set_tenant_budget(&self, db: DatabaseId, tenant: TenantId, max_bytes: usize) {
203 let mut map = self
204 .tenant_budgets
205 .write()
206 .unwrap_or_else(|p| p.into_inner());
207 map.insert((db, tenant), ScopedBudget::new(max_bytes));
208 }
209
210 pub fn clear_tenant_budget(&self, db: DatabaseId, tenant: TenantId) {
212 let mut map = self
213 .tenant_budgets
214 .write()
215 .unwrap_or_else(|p| p.into_inner());
216 map.remove(&(db, tenant));
217 }
218
219 pub fn try_reserve(
234 &self,
235 db: DatabaseId,
236 tenant: TenantId,
237 engine: EngineId,
238 size: usize,
239 ) -> Result<ReservationToken> {
240 let global_arc = Arc::clone(&self.global_counter);
242 if size > 0 {
243 loop {
244 let current = global_arc.allocated.load(Ordering::Relaxed);
245 if current + size > global_arc.ceiling {
246 return Err(MemError::GlobalCeilingExceeded {
247 allocated: current,
248 ceiling: global_arc.ceiling,
249 requested: size,
250 });
251 }
252 match global_arc.allocated.compare_exchange_weak(
253 current,
254 current + size,
255 Ordering::AcqRel,
256 Ordering::Relaxed,
257 ) {
258 Ok(_) => break,
259 Err(_) => continue,
260 }
261 }
262 }
263
264 let db_counter = {
266 let map = self
267 .database_budgets
268 .read()
269 .unwrap_or_else(|p| p.into_inner());
270 if let Some(budget) = map.get(&db) {
271 match budget.try_reserve(size) {
272 Some(arc) => Some(arc),
273 None => {
274 if size > 0 {
276 global_arc.allocated.fetch_sub(size, Ordering::Relaxed);
277 }
278 return Err(MemError::DatabaseBudgetExhausted {
279 db,
280 requested: size,
281 available: budget.available(),
282 limit: budget.limit,
283 });
284 }
285 }
286 } else {
287 None
288 }
289 };
290
291 let tenant_counter = {
293 let map = self
294 .tenant_budgets
295 .read()
296 .unwrap_or_else(|p| p.into_inner());
297 if let Some(budget) = map.get(&(db, tenant)) {
298 match budget.try_reserve(size) {
299 Some(arc) => Some(arc),
300 None => {
301 if let Some(ref ctr) = db_counter
303 && size > 0
304 {
305 ctr.fetch_sub(size, Ordering::Relaxed);
306 }
307 if size > 0 {
308 global_arc.allocated.fetch_sub(size, Ordering::Relaxed);
309 }
310 return Err(MemError::TenantBudgetExhausted {
311 db,
312 tenant,
313 requested: size,
314 available: budget.available(),
315 limit: budget.limit,
316 });
317 }
318 }
319 } else {
320 None
321 }
322 };
323
324 let engine_budget = self
326 .budgets
327 .get(&engine)
328 .ok_or(MemError::UnknownEngine(engine))?;
329
330 let engine_counter = if let Some(arc) = engine_budget.try_reserve_arc(size) {
331 Some(arc)
332 } else {
333 if let Some(ref ctr) = tenant_counter
335 && size > 0
336 {
337 ctr.fetch_sub(size, Ordering::Relaxed);
338 }
339 if let Some(ref ctr) = db_counter
340 && size > 0
341 {
342 ctr.fetch_sub(size, Ordering::Relaxed);
343 }
344 if size > 0 {
345 global_arc.allocated.fetch_sub(size, Ordering::Relaxed);
346 }
347 return Err(MemError::BudgetExhausted {
348 engine,
349 requested: size,
350 available: engine_budget.available(),
351 limit: engine_budget.limit(),
352 });
353 };
354
355 Ok(ReservationToken::new(
356 crate::reservation_token::ReservationParams {
357 global_counter: global_arc,
358 database_counter: db_counter,
359 tenant_counter,
360 engine_counter,
361 size,
362 db,
363 tenant,
364 engine,
365 },
366 ))
367 }
368
369 pub fn release(&self, engine: EngineId, size: usize) {
376 if let Some(budget) = self.budgets.get(&engine) {
377 budget.release(size);
378 }
379 crate::budget::atomic_saturating_sub(&self.global_counter.allocated, size);
384 }
385
    /// Returns the budget configured for `engine`, or `None` when the engine
    /// has no entry.
    pub fn budget(&self, engine: EngineId) -> Option<&Budget> {
        self.budgets.get(&engine)
    }
390
    /// Returns the global ceiling this governor was configured with.
    pub fn global_ceiling(&self) -> usize {
        self.global_ceiling
    }
395
396 pub fn total_allocated(&self) -> usize {
398 self.budgets.values().map(|b| b.allocated()).sum()
399 }
400
401 pub fn total_over_release_count(&self) -> usize {
408 self.budgets.values().map(|b| b.over_release_count()).sum()
409 }
410
411 pub fn global_utilization_percent(&self) -> u8 {
414 if self.global_ceiling == 0 {
415 return 100;
416 }
417 ((self.total_allocated() as u128 * 100) / self.global_ceiling as u128).min(100) as u8
418 }
419
420 pub fn engine_pressure(&self, engine: EngineId) -> PressureLevel {
422 self.budgets
423 .get(&engine)
424 .map(|b| self.thresholds.level_for(b.utilization_percent()))
425 .unwrap_or(PressureLevel::Emergency)
426 }
427
428 pub fn global_pressure(&self) -> PressureLevel {
430 self.thresholds.level_for(self.global_utilization_percent())
431 }
432
433 pub fn worst_engine_pressure(&self) -> PressureLevel {
439 self.budgets
440 .values()
441 .map(|b| self.thresholds.level_for(b.utilization_percent()))
442 .max()
443 .unwrap_or(PressureLevel::Normal)
444 }
445
    /// Replaces the thresholds used by the pressure queries.
    pub fn set_thresholds(&mut self, thresholds: PressureThresholds) {
        self.thresholds = thresholds;
    }
450
451 pub fn snapshot(&self) -> Vec<EngineSnapshot> {
453 self.budgets
454 .iter()
455 .map(|(engine, budget)| EngineSnapshot {
456 engine: *engine,
457 allocated: budget.allocated(),
458 limit: budget.limit(),
459 peak: budget.peak(),
460 rejections: budget.rejections(),
461 utilization_percent: budget.utilization_percent(),
462 })
463 .collect()
464 }
465}
466
/// Values read from one engine's budget by `MemoryGovernor::snapshot`.
#[derive(Debug, Clone)]
pub struct EngineSnapshot {
    /// Engine this entry describes.
    pub engine: EngineId,
    /// `Budget::allocated()` at the time of the snapshot.
    pub allocated: usize,
    /// `Budget::limit()` at the time of the snapshot.
    pub limit: usize,
    /// `Budget::peak()` at the time of the snapshot.
    pub peak: usize,
    /// `Budget::rejections()` at the time of the snapshot.
    pub rejections: usize,
    /// `Budget::utilization_percent()` at the time of the snapshot.
    pub utilization_percent: u8,
}
477
478#[cfg(test)]
479mod tests {
480 use std::collections::HashMap;
481 use std::sync::Arc;
482 use std::thread;
483
484 use nodedb_types::{DatabaseId, TenantId};
485
486 use super::*;
487
488 fn test_config() -> GovernorConfig {
489 let mut engine_limits = HashMap::new();
490 engine_limits.insert(EngineId::Vector, 4096);
491 engine_limits.insert(EngineId::Query, 2048);
492 engine_limits.insert(EngineId::Timeseries, 1024);
493
494 GovernorConfig {
495 global_ceiling: 8192,
496 engine_limits,
497 }
498 }
499
    /// Database id shared by the tests in this module.
    fn db() -> DatabaseId {
        DatabaseId::DEFAULT
    }
503
    /// Tenant id shared by the tests in this module.
    fn tenant() -> TenantId {
        TenantId::new(1)
    }
507
508 #[test]
511 fn reserve_within_budget() {
512 let gov = MemoryGovernor::new(test_config()).unwrap();
513 let tok = gov
514 .try_reserve(db(), tenant(), EngineId::Vector, 1000)
515 .unwrap();
516 assert_eq!(gov.budget(EngineId::Vector).unwrap().allocated(), 1000);
517 assert_eq!(tok.size(), 1000);
518 }
519
520 #[test]
521 fn reserve_exceeds_engine_budget() {
522 let gov = MemoryGovernor::new(test_config()).unwrap();
523 let err = gov
524 .try_reserve(db(), tenant(), EngineId::Query, 3000)
525 .unwrap_err();
526 assert!(matches!(err, MemError::BudgetExhausted { .. }));
527 }
528
529 #[test]
530 fn reserve_exceeds_global_ceiling() {
531 let gov = MemoryGovernor::new(test_config()).unwrap();
532 let _t1 = gov
534 .try_reserve(db(), tenant(), EngineId::Vector, 4096)
535 .unwrap();
536 let _t2 = gov
537 .try_reserve(db(), tenant(), EngineId::Query, 2048)
538 .unwrap();
539 let _t3 = gov
540 .try_reserve(db(), tenant(), EngineId::Timeseries, 1024)
541 .unwrap();
542 let err = gov
544 .try_reserve(db(), tenant(), EngineId::Timeseries, 2000)
545 .unwrap_err();
546 assert!(matches!(
547 err,
548 MemError::BudgetExhausted { .. } | MemError::GlobalCeilingExceeded { .. }
549 ));
550 }
551
552 #[test]
555 fn raii_release_returns_to_baseline() {
556 let gov = MemoryGovernor::new(test_config()).unwrap();
557
558 {
559 let tok = gov
560 .try_reserve(db(), tenant(), EngineId::Vector, 1000)
561 .unwrap();
562 assert_eq!(gov.budget(EngineId::Vector).unwrap().allocated(), 1000);
563 assert_eq!(tok.size(), 1000);
564 } assert_eq!(
567 gov.budget(EngineId::Vector).unwrap().allocated(),
568 0,
569 "engine counter must be returned on drop"
570 );
571 }
572
573 #[test]
576 fn database_cap_denies_even_with_tenant_headroom() {
577 let gov = MemoryGovernor::new(test_config()).unwrap();
578 gov.set_database_budget(db(), 500);
580 gov.set_tenant_budget(db(), tenant(), 4096);
582
583 let err = gov
586 .try_reserve(db(), tenant(), EngineId::Vector, 600)
587 .unwrap_err();
588 assert!(
589 matches!(err, MemError::DatabaseBudgetExhausted { .. }),
590 "expected DatabaseBudgetExhausted, got {err:?}"
591 );
592 }
593
594 #[test]
595 fn global_cap_denies_even_with_database_and_tenant_headroom() {
596 let mut engine_limits = HashMap::new();
600 engine_limits.insert(EngineId::Vector, 200);
601 let gov = MemoryGovernor::new(GovernorConfig {
602 global_ceiling: 200,
603 engine_limits,
604 })
605 .unwrap();
606 gov.set_database_budget(db(), 1024);
607 gov.set_tenant_budget(db(), tenant(), 1024);
608
609 let err = gov
610 .try_reserve(db(), tenant(), EngineId::Vector, 300)
611 .unwrap_err();
612 assert!(
613 matches!(err, MemError::GlobalCeilingExceeded { .. }),
614 "expected GlobalCeilingExceeded, got {err:?}"
615 );
616 }
617
618 #[test]
619 fn tenant_cap_denies_with_db_headroom() {
620 let gov = MemoryGovernor::new(test_config()).unwrap();
621 gov.set_database_budget(db(), 4096);
622 gov.set_tenant_budget(db(), tenant(), 300);
623
624 let err = gov
625 .try_reserve(db(), tenant(), EngineId::Vector, 400)
626 .unwrap_err();
627 assert!(
628 matches!(err, MemError::TenantBudgetExhausted { .. }),
629 "expected TenantBudgetExhausted, got {err:?}"
630 );
631 }
632
633 #[test]
636 fn partial_increments_rolled_back_on_db_failure() {
637 let gov = MemoryGovernor::new(test_config()).unwrap();
638 gov.set_database_budget(db(), 50);
639
640 let _ = gov
642 .try_reserve(db(), tenant(), EngineId::Vector, 100)
643 .unwrap_err();
644
645 assert_eq!(
647 gov.global_counter.allocated.load(Ordering::Relaxed),
648 0,
649 "global counter must be rolled back on database-layer failure"
650 );
651 }
652
653 #[test]
654 fn partial_increments_rolled_back_on_tenant_failure() {
655 let gov = MemoryGovernor::new(test_config()).unwrap();
656 gov.set_database_budget(db(), 4096);
657 gov.set_tenant_budget(db(), tenant(), 50);
658
659 let _ = gov
660 .try_reserve(db(), tenant(), EngineId::Vector, 100)
661 .unwrap_err();
662
663 assert_eq!(
665 gov.global_counter.allocated.load(Ordering::Relaxed),
666 0,
667 "global counter must be rolled back on tenant-layer failure"
668 );
669 let db_map = gov.database_budgets.read().unwrap();
670 let db_alloc = db_map[&db()].allocated.load(Ordering::Relaxed);
671 assert_eq!(db_alloc, 0, "database counter must be rolled back");
672 }
673
674 #[test]
677 fn concurrent_reserves_never_exceed_cap() {
678 let mut limits = HashMap::new();
679 limits.insert(EngineId::Vector, 10_000);
680 let gov = Arc::new(
681 MemoryGovernor::new(GovernorConfig {
682 global_ceiling: 10_000,
683 engine_limits: limits,
684 })
685 .unwrap(),
686 );
687 gov.set_database_budget(DatabaseId::DEFAULT, 10_000);
688
689 let n_threads = 8;
691 let reserve_size = 1_000;
692 let mut handles = Vec::new();
693
694 for i in 0..n_threads {
695 let gov_clone = Arc::clone(&gov);
696 handles.push(thread::spawn(move || {
697 gov_clone.try_reserve(
698 DatabaseId::DEFAULT,
699 TenantId::new(i as u64),
700 EngineId::Vector,
701 reserve_size,
702 )
703 }));
704 }
705
706 let results: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
707 let successful: Vec<_> = results.into_iter().filter_map(|r| r.ok()).collect();
708
709 assert!(
711 successful.len() <= 10,
712 "expected at most 10 successful reservations, got {}",
713 successful.len()
714 );
715
716 let engine_alloc = gov.budget(EngineId::Vector).unwrap().allocated();
717 assert!(
718 engine_alloc <= 10_000,
719 "engine total {engine_alloc} must not exceed cap 10000"
720 );
721
722 let global_alloc = gov.global_counter.allocated.load(Ordering::Relaxed);
723 assert!(
724 global_alloc <= 10_000,
725 "global total {global_alloc} must not exceed ceiling 10000"
726 );
727 }
728
729 #[test]
732 fn unknown_engine_rejected() {
733 let gov = MemoryGovernor::new(test_config()).unwrap();
734 let err = gov
735 .try_reserve(db(), tenant(), EngineId::Crdt, 100)
736 .unwrap_err();
737 assert!(matches!(err, MemError::UnknownEngine(EngineId::Crdt)));
738 }
739
740 #[test]
741 fn snapshot_reports_all_engines() {
742 let gov = MemoryGovernor::new(test_config()).unwrap();
743 let _tok = gov
744 .try_reserve(db(), tenant(), EngineId::Vector, 2048)
745 .unwrap();
746
747 let snap = gov.snapshot();
748 assert_eq!(snap.len(), 3);
749
750 let vector_snap = snap.iter().find(|s| s.engine == EngineId::Vector).unwrap();
751 assert_eq!(vector_snap.allocated, 2048);
752 assert_eq!(vector_snap.limit, 4096);
753 assert_eq!(vector_snap.utilization_percent, 50);
754 }
755
756 #[test]
757 fn engine_pressure_levels() {
758 let gov = MemoryGovernor::new(test_config()).unwrap();
759
760 assert_eq!(gov.engine_pressure(EngineId::Vector), PressureLevel::Normal);
761
762 let _tok1 = gov
763 .try_reserve(db(), tenant(), EngineId::Vector, 2868)
764 .unwrap();
765 assert_eq!(
766 gov.engine_pressure(EngineId::Vector),
767 PressureLevel::Warning
768 );
769 }
770
771 #[test]
772 fn worst_engine_pressure_picks_highest() {
773 let gov = MemoryGovernor::new(test_config()).unwrap();
774 assert_eq!(gov.worst_engine_pressure(), PressureLevel::Normal);
775
776 let _tok = gov
779 .try_reserve(db(), tenant(), EngineId::Query, 1800)
780 .unwrap();
781 assert_eq!(gov.engine_pressure(EngineId::Vector), PressureLevel::Normal);
782 assert_eq!(gov.worst_engine_pressure(), PressureLevel::Critical);
783 }
784
785 #[test]
786 fn invalid_config_rejected() {
787 let mut config = test_config();
788 config.global_ceiling = 100;
789 assert!(MemoryGovernor::new(config).is_err());
790 }
791}