1use std::collections::HashMap;
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::sync::{Arc, RwLock};
26
27use nodedb_types::{DatabaseId, TenantId};
28
29use crate::budget::Budget;
30use crate::engine::EngineId;
31use crate::error::{MemError, Result};
32use crate::pressure::{PressureLevel, PressureThresholds};
33use crate::reservation_token::ReservationToken;
34
/// Process-wide allocation counter paired with the ceiling it is checked against.
///
/// `allocated` is an atomic so it can be updated through a shared reference;
/// `ceiling` is a plain value fixed when the counter is built.
pub struct GlobalCounter {
    /// Bytes currently accounted for.
    pub(crate) allocated: AtomicUsize,
    /// Upper bound that `allocated` is compared against by reservers.
    pub(crate) ceiling: usize,
}

impl std::fmt::Debug for GlobalCounter {
    /// Formats the counter as a struct, printing the *loaded* value of
    /// `allocated` rather than the atomic wrapper itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A relaxed load is enough: this is a diagnostic snapshot only.
        let in_use = self.allocated.load(Ordering::Relaxed);

        let mut out = f.debug_struct("GlobalCounter");
        out.field("allocated", &in_use);
        out.field("ceiling", &self.ceiling);
        out.finish()
    }
}
52
/// A byte cap with its own shared allocation counter.
///
/// The counter lives behind an `Arc` so that a successful reservation can hand
/// out a clone of it to the caller.
#[derive(Debug)]
struct ScopedBudget {
    /// Maximum number of bytes that may be reserved at once.
    limit: usize,
    /// Bytes currently reserved against this budget.
    allocated: Arc<AtomicUsize>,
}

impl ScopedBudget {
    /// Creates a budget capped at `limit` bytes with nothing reserved yet.
    fn new(limit: usize) -> Self {
        Self {
            limit,
            allocated: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Atomically adds `size` bytes to the counter if the result stays within
    /// `limit`.
    ///
    /// Returns a clone of the counter `Arc` on success, or `None` when the
    /// request does not fit. A request whose sum with the current allocation
    /// would overflow `usize` is treated as "does not fit" instead of
    /// panicking (debug builds) or wrapping past the limit check (release
    /// builds).
    fn try_reserve(&self, size: usize) -> Option<Arc<AtomicUsize>> {
        let mut current = self.allocated.load(Ordering::Relaxed);
        loop {
            // `checked_add` turns arithmetic overflow into a plain rejection.
            let next = current
                .checked_add(size)
                .filter(|&total| total <= self.limit)?;

            match self.allocated.compare_exchange_weak(
                current,
                next,
                Ordering::AcqRel,
                Ordering::Relaxed,
            ) {
                Ok(_) => return Some(Arc::clone(&self.allocated)),
                // Lost a race (or spurious failure): retry against the value
                // the CAS actually observed.
                Err(observed) => current = observed,
            }
        }
    }

    /// Bytes still reservable; zero if the counter is at or above `limit`.
    fn available(&self) -> usize {
        let alloc = self.allocated.load(Ordering::Relaxed);
        self.limit.saturating_sub(alloc)
    }
}
95
96#[derive(Debug, Clone)]
98pub struct GovernorConfig {
99 pub global_ceiling: usize,
102
103 pub engine_limits: HashMap<EngineId, usize>,
105}
106
107impl GovernorConfig {
108 pub fn validate(&self) -> Result<()> {
110 let total: usize = self.engine_limits.values().sum();
111 if total > self.global_ceiling {
112 return Err(MemError::GlobalCeilingExceeded {
113 allocated: total,
114 ceiling: self.global_ceiling,
115 requested: 0,
116 });
117 }
118 Ok(())
119 }
120}
121
122#[derive(Debug)]
128pub struct MemoryGovernor {
129 budgets: HashMap<EngineId, Budget>,
131
132 global_counter: Arc<GlobalCounter>,
134
135 global_ceiling: usize,
137
138 thresholds: PressureThresholds,
140
141 database_budgets: RwLock<HashMap<DatabaseId, ScopedBudget>>,
144
145 tenant_budgets: RwLock<HashMap<(DatabaseId, TenantId), ScopedBudget>>,
148}
149
150impl MemoryGovernor {
151 pub fn new(config: GovernorConfig) -> Result<Self> {
153 config.validate()?;
154
155 let mut budgets = HashMap::new();
156 for (engine, limit) in &config.engine_limits {
157 budgets.insert(*engine, Budget::new(*limit));
158 }
159
160 let global_counter = Arc::new(GlobalCounter {
161 allocated: AtomicUsize::new(0),
162 ceiling: config.global_ceiling,
163 });
164
165 Ok(Self {
166 budgets,
167 global_counter,
168 global_ceiling: config.global_ceiling,
169 thresholds: PressureThresholds::default(),
170 database_budgets: RwLock::new(HashMap::new()),
171 tenant_budgets: RwLock::new(HashMap::new()),
172 })
173 }
174
175 pub fn set_database_budget(&self, db: DatabaseId, max_bytes: usize) {
183 let mut map = self
184 .database_budgets
185 .write()
186 .unwrap_or_else(|p| p.into_inner());
187 map.insert(db, ScopedBudget::new(max_bytes));
188 }
189
190 pub fn clear_database_budget(&self, db: DatabaseId) {
192 let mut map = self
193 .database_budgets
194 .write()
195 .unwrap_or_else(|p| p.into_inner());
196 map.remove(&db);
197 }
198
199 pub fn set_tenant_budget(&self, db: DatabaseId, tenant: TenantId, max_bytes: usize) {
203 let mut map = self
204 .tenant_budgets
205 .write()
206 .unwrap_or_else(|p| p.into_inner());
207 map.insert((db, tenant), ScopedBudget::new(max_bytes));
208 }
209
210 pub fn clear_tenant_budget(&self, db: DatabaseId, tenant: TenantId) {
212 let mut map = self
213 .tenant_budgets
214 .write()
215 .unwrap_or_else(|p| p.into_inner());
216 map.remove(&(db, tenant));
217 }
218
219 pub fn try_reserve(
234 &self,
235 db: DatabaseId,
236 tenant: TenantId,
237 engine: EngineId,
238 size: usize,
239 ) -> Result<ReservationToken> {
240 let global_arc = Arc::clone(&self.global_counter);
242 if size > 0 {
243 loop {
244 let current = global_arc.allocated.load(Ordering::Relaxed);
245 if current + size > global_arc.ceiling {
246 return Err(MemError::GlobalCeilingExceeded {
247 allocated: current,
248 ceiling: global_arc.ceiling,
249 requested: size,
250 });
251 }
252 match global_arc.allocated.compare_exchange_weak(
253 current,
254 current + size,
255 Ordering::AcqRel,
256 Ordering::Relaxed,
257 ) {
258 Ok(_) => break,
259 Err(_) => continue,
260 }
261 }
262 }
263
264 let db_counter = {
266 let map = self
267 .database_budgets
268 .read()
269 .unwrap_or_else(|p| p.into_inner());
270 if let Some(budget) = map.get(&db) {
271 match budget.try_reserve(size) {
272 Some(arc) => Some(arc),
273 None => {
274 if size > 0 {
276 global_arc.allocated.fetch_sub(size, Ordering::Relaxed);
277 }
278 return Err(MemError::DatabaseBudgetExhausted {
279 db,
280 requested: size,
281 available: budget.available(),
282 limit: budget.limit,
283 });
284 }
285 }
286 } else {
287 None
288 }
289 };
290
291 let tenant_counter = {
293 let map = self
294 .tenant_budgets
295 .read()
296 .unwrap_or_else(|p| p.into_inner());
297 if let Some(budget) = map.get(&(db, tenant)) {
298 match budget.try_reserve(size) {
299 Some(arc) => Some(arc),
300 None => {
301 if let Some(ref ctr) = db_counter
303 && size > 0
304 {
305 ctr.fetch_sub(size, Ordering::Relaxed);
306 }
307 if size > 0 {
308 global_arc.allocated.fetch_sub(size, Ordering::Relaxed);
309 }
310 return Err(MemError::TenantBudgetExhausted {
311 db,
312 tenant,
313 requested: size,
314 available: budget.available(),
315 limit: budget.limit,
316 });
317 }
318 }
319 } else {
320 None
321 }
322 };
323
324 let engine_budget = self
326 .budgets
327 .get(&engine)
328 .ok_or(MemError::UnknownEngine(engine))?;
329
330 let engine_counter = if let Some(arc) = engine_budget.try_reserve_arc(size) {
331 Some(arc)
332 } else {
333 if let Some(ref ctr) = tenant_counter
335 && size > 0
336 {
337 ctr.fetch_sub(size, Ordering::Relaxed);
338 }
339 if let Some(ref ctr) = db_counter
340 && size > 0
341 {
342 ctr.fetch_sub(size, Ordering::Relaxed);
343 }
344 if size > 0 {
345 global_arc.allocated.fetch_sub(size, Ordering::Relaxed);
346 }
347 return Err(MemError::BudgetExhausted {
348 engine,
349 requested: size,
350 available: engine_budget.available(),
351 limit: engine_budget.limit(),
352 });
353 };
354
355 Ok(ReservationToken::new(
356 crate::reservation_token::ReservationParams {
357 global_counter: global_arc,
358 database_counter: db_counter,
359 tenant_counter,
360 engine_counter,
361 size,
362 db,
363 tenant,
364 engine,
365 },
366 ))
367 }
368
369 pub fn release(&self, engine: EngineId, size: usize) {
376 if let Some(budget) = self.budgets.get(&engine) {
377 budget.release(size);
378 }
379 if size > 0 {
381 self.global_counter
382 .allocated
383 .fetch_sub(size, Ordering::Relaxed);
384 }
385 }
386
387 pub fn budget(&self, engine: EngineId) -> Option<&Budget> {
389 self.budgets.get(&engine)
390 }
391
392 pub fn global_ceiling(&self) -> usize {
394 self.global_ceiling
395 }
396
397 pub fn total_allocated(&self) -> usize {
399 self.budgets.values().map(|b| b.allocated()).sum()
400 }
401
402 pub fn total_over_release_count(&self) -> usize {
409 self.budgets.values().map(|b| b.over_release_count()).sum()
410 }
411
412 pub fn global_utilization_percent(&self) -> u8 {
414 if self.global_ceiling == 0 {
415 return 100;
416 }
417 ((self.total_allocated() * 100) / self.global_ceiling).min(100) as u8
418 }
419
420 pub fn engine_pressure(&self, engine: EngineId) -> PressureLevel {
422 self.budgets
423 .get(&engine)
424 .map(|b| self.thresholds.level_for(b.utilization_percent()))
425 .unwrap_or(PressureLevel::Emergency)
426 }
427
428 pub fn global_pressure(&self) -> PressureLevel {
430 self.thresholds.level_for(self.global_utilization_percent())
431 }
432
433 pub fn set_thresholds(&mut self, thresholds: PressureThresholds) {
435 self.thresholds = thresholds;
436 }
437
438 pub fn snapshot(&self) -> Vec<EngineSnapshot> {
440 self.budgets
441 .iter()
442 .map(|(engine, budget)| EngineSnapshot {
443 engine: *engine,
444 allocated: budget.allocated(),
445 limit: budget.limit(),
446 peak: budget.peak(),
447 rejections: budget.rejections(),
448 utilization_percent: budget.utilization_percent(),
449 })
450 .collect()
451 }
452}
453
454#[derive(Debug, Clone)]
456pub struct EngineSnapshot {
457 pub engine: EngineId,
458 pub allocated: usize,
459 pub limit: usize,
460 pub peak: usize,
461 pub rejections: usize,
462 pub utilization_percent: u8,
463}
464
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::thread;

    use nodedb_types::{DatabaseId, TenantId};

    use super::*;

    /// Three engines (4096 + 2048 + 1024) under an 8192 ceiling.
    fn test_config() -> GovernorConfig {
        GovernorConfig {
            global_ceiling: 8192,
            engine_limits: HashMap::from([
                (EngineId::Vector, 4096),
                (EngineId::Query, 2048),
                (EngineId::Timeseries, 1024),
            ]),
        }
    }

    /// Governor built from [`test_config`].
    fn governor() -> MemoryGovernor {
        MemoryGovernor::new(test_config()).unwrap()
    }

    /// Governor with only the vector engine, whose limit equals the ceiling.
    fn vector_only_governor(cap: usize) -> MemoryGovernor {
        MemoryGovernor::new(GovernorConfig {
            global_ceiling: cap,
            engine_limits: HashMap::from([(EngineId::Vector, cap)]),
        })
        .unwrap()
    }

    fn db() -> DatabaseId {
        DatabaseId::DEFAULT
    }

    fn tenant() -> TenantId {
        TenantId::new(1)
    }

    // ── basic reservation ────────────────────────────────────────────────

    #[test]
    fn reserve_within_budget() {
        let gov = governor();
        let token = gov
            .try_reserve(db(), tenant(), EngineId::Vector, 1000)
            .unwrap();

        assert_eq!(gov.budget(EngineId::Vector).unwrap().allocated(), 1000);
        assert_eq!(token.size(), 1000);
    }

    #[test]
    fn reserve_exceeds_engine_budget() {
        let gov = governor();
        let outcome = gov.try_reserve(db(), tenant(), EngineId::Query, 3000);

        assert!(matches!(
            outcome.unwrap_err(),
            MemError::BudgetExhausted { .. }
        ));
    }

    #[test]
    fn reserve_exceeds_global_ceiling() {
        let gov = governor();

        // Fill every engine to its limit; the tokens stay alive so nothing
        // is handed back before the final attempt.
        let _vector = gov
            .try_reserve(db(), tenant(), EngineId::Vector, 4096)
            .unwrap();
        let _query = gov
            .try_reserve(db(), tenant(), EngineId::Query, 2048)
            .unwrap();
        let _timeseries = gov
            .try_reserve(db(), tenant(), EngineId::Timeseries, 1024)
            .unwrap();

        // Either layer may be the one that refuses.
        let err = gov
            .try_reserve(db(), tenant(), EngineId::Timeseries, 2000)
            .unwrap_err();
        assert!(matches!(
            err,
            MemError::BudgetExhausted { .. } | MemError::GlobalCeilingExceeded { .. }
        ));
    }

    // ── token drop ───────────────────────────────────────────────────────

    #[test]
    fn raii_release_returns_to_baseline() {
        let gov = governor();

        {
            let token = gov
                .try_reserve(db(), tenant(), EngineId::Vector, 1000)
                .unwrap();
            assert_eq!(gov.budget(EngineId::Vector).unwrap().allocated(), 1000);
            assert_eq!(token.size(), 1000);
        }

        // The token went out of scope at the end of the block above.
        assert_eq!(
            gov.budget(EngineId::Vector).unwrap().allocated(),
            0,
            "engine counter must be returned on drop"
        );
    }

    // ── layered caps ─────────────────────────────────────────────────────

    #[test]
    fn database_cap_denies_even_with_tenant_headroom() {
        let gov = governor();
        gov.set_database_budget(db(), 500);
        gov.set_tenant_budget(db(), tenant(), 4096);

        let err = gov
            .try_reserve(db(), tenant(), EngineId::Vector, 600)
            .unwrap_err();
        assert!(
            matches!(err, MemError::DatabaseBudgetExhausted { .. }),
            "expected DatabaseBudgetExhausted, got {err:?}"
        );
    }

    #[test]
    fn global_cap_denies_even_with_database_and_tenant_headroom() {
        let gov = vector_only_governor(200);
        gov.set_database_budget(db(), 1024);
        gov.set_tenant_budget(db(), tenant(), 1024);

        let err = gov
            .try_reserve(db(), tenant(), EngineId::Vector, 300)
            .unwrap_err();
        assert!(
            matches!(err, MemError::GlobalCeilingExceeded { .. }),
            "expected GlobalCeilingExceeded, got {err:?}"
        );
    }

    #[test]
    fn tenant_cap_denies_with_db_headroom() {
        let gov = governor();
        gov.set_database_budget(db(), 4096);
        gov.set_tenant_budget(db(), tenant(), 300);

        let err = gov
            .try_reserve(db(), tenant(), EngineId::Vector, 400)
            .unwrap_err();
        assert!(
            matches!(err, MemError::TenantBudgetExhausted { .. }),
            "expected TenantBudgetExhausted, got {err:?}"
        );
    }

    // ── rollback of partial increments ───────────────────────────────────

    #[test]
    fn partial_increments_rolled_back_on_db_failure() {
        let gov = governor();
        gov.set_database_budget(db(), 50);

        let _ = gov
            .try_reserve(db(), tenant(), EngineId::Vector, 100)
            .unwrap_err();

        let global_after = gov.global_counter.allocated.load(Ordering::Relaxed);
        assert_eq!(
            global_after, 0,
            "global counter must be rolled back on database-layer failure"
        );
    }

    #[test]
    fn partial_increments_rolled_back_on_tenant_failure() {
        let gov = governor();
        gov.set_database_budget(db(), 4096);
        gov.set_tenant_budget(db(), tenant(), 50);

        let _ = gov
            .try_reserve(db(), tenant(), EngineId::Vector, 100)
            .unwrap_err();

        let global_after = gov.global_counter.allocated.load(Ordering::Relaxed);
        assert_eq!(
            global_after, 0,
            "global counter must be rolled back on tenant-layer failure"
        );

        let db_map = gov.database_budgets.read().unwrap();
        let db_alloc = db_map[&db()].allocated.load(Ordering::Relaxed);
        assert_eq!(db_alloc, 0, "database counter must be rolled back");
    }

    // ── concurrency ──────────────────────────────────────────────────────

    #[test]
    fn concurrent_reserves_never_exceed_cap() {
        let gov = Arc::new(vector_only_governor(10_000));
        gov.set_database_budget(DatabaseId::DEFAULT, 10_000);

        let n_threads = 8;
        let reserve_size = 1_000;

        // `collect` is eager, so every thread is spawned before any is joined.
        let workers: Vec<_> = (0..n_threads)
            .map(|i| {
                let shared = Arc::clone(&gov);
                thread::spawn(move || {
                    shared.try_reserve(
                        DatabaseId::DEFAULT,
                        TenantId::new(i as u64),
                        EngineId::Vector,
                        reserve_size,
                    )
                })
            })
            .collect();

        // Keep the successful tokens alive until the assertions have run.
        let successful: Vec<_> = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .filter_map(|outcome| outcome.ok())
            .collect();

        assert!(
            successful.len() <= 10,
            "expected at most 10 successful reservations, got {}",
            successful.len()
        );

        let engine_alloc = gov.budget(EngineId::Vector).unwrap().allocated();
        assert!(
            engine_alloc <= 10_000,
            "engine total {engine_alloc} must not exceed cap 10000"
        );

        let global_alloc = gov.global_counter.allocated.load(Ordering::Relaxed);
        assert!(
            global_alloc <= 10_000,
            "global total {global_alloc} must not exceed ceiling 10000"
        );
    }

    // ── miscellaneous ────────────────────────────────────────────────────

    #[test]
    fn unknown_engine_rejected() {
        let gov = governor();
        let outcome = gov.try_reserve(db(), tenant(), EngineId::Crdt, 100);

        assert!(matches!(
            outcome.unwrap_err(),
            MemError::UnknownEngine(EngineId::Crdt)
        ));
    }

    #[test]
    fn snapshot_reports_all_engines() {
        let gov = governor();
        let _held = gov
            .try_reserve(db(), tenant(), EngineId::Vector, 2048)
            .unwrap();

        let entries = gov.snapshot();
        assert_eq!(entries.len(), 3);

        let vector_entry = entries
            .iter()
            .find(|entry| entry.engine == EngineId::Vector)
            .unwrap();
        assert_eq!(vector_entry.allocated, 2048);
        assert_eq!(vector_entry.limit, 4096);
        assert_eq!(vector_entry.utilization_percent, 50);
    }

    #[test]
    fn engine_pressure_levels() {
        let gov = governor();

        assert_eq!(gov.engine_pressure(EngineId::Vector), PressureLevel::Normal);

        // 2868 of 4096 bytes is just over 70 % utilization.
        let _held = gov
            .try_reserve(db(), tenant(), EngineId::Vector, 2868)
            .unwrap();
        assert_eq!(
            gov.engine_pressure(EngineId::Vector),
            PressureLevel::Warning
        );
    }

    #[test]
    fn invalid_config_rejected() {
        // Engine limits sum to 7168, far above a ceiling of 100.
        let config = GovernorConfig {
            global_ceiling: 100,
            ..test_config()
        };
        assert!(MemoryGovernor::new(config).is_err());
    }
}