1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
12
13use super::cgroup;
14
/// Reads the environment variable named `{prefix}_{suffix}` and parses it as `T`.
///
/// Yields `None` when the variable is unset, is not valid Unicode, or its value
/// does not parse as `T`. The raw value is parsed as-is (no trimming).
fn env_parsed<T: std::str::FromStr>(prefix: &str, suffix: &str) -> Option<T> {
    let key = format!("{prefix}_{suffix}");
    match std::env::var(&key) {
        Ok(raw) => raw.parse::<T>().ok(),
        Err(_) => None,
    }
}
21
/// Coarse classification of how full a `MemoryGuard` is.
///
/// Variants are declared from least to most severe, so the derived ordering
/// ranks them by severity (`Low < Medium < High`), which allows checks such as
/// `guard.pressure() >= MemoryPressure::Medium`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum MemoryPressure {
    /// Usage is below half of the limit.
    Low,
    /// Usage is at least half of the limit but below the pressure threshold.
    Medium,
    /// Usage has reached the configured pressure threshold.
    High,
}
32
/// Settings used to build a `MemoryGuard`.
///
/// Every field carries a serde default, so deserializing a partial table fills
/// the missing fields with the same values `MemoryGuardConfig::default()` uses.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct MemoryGuardConfig {
    /// Hard byte limit. `0` (the default) makes `MemoryGuard::new` derive the
    /// limit from `cgroup::detect_memory_limit()` scaled by `cgroup_headroom`.
    #[serde(default)]
    pub limit_bytes: u64,
    /// Fraction of the limit at or above which the guard reports pressure
    /// (defaults to `DEFAULT_PRESSURE_THRESHOLD`).
    #[serde(default = "default_pressure_threshold")]
    pub pressure_threshold: f64,
    /// Multiplier applied to the auto-detected limit; only consulted when
    /// `limit_bytes` is `0` (defaults to `DEFAULT_CGROUP_HEADROOM`).
    #[serde(default = "default_cgroup_headroom")]
    pub cgroup_headroom: f64,
}
57
/// Default fraction of the byte limit at or above which the guard reports
/// pressure.
const DEFAULT_PRESSURE_THRESHOLD: f64 = 0.80;

/// Default multiplier applied to an auto-detected memory limit.
const DEFAULT_CGROUP_HEADROOM: f64 = 0.85;

/// Serde default for `MemoryGuardConfig::pressure_threshold`.
const fn default_pressure_threshold() -> f64 {
    DEFAULT_PRESSURE_THRESHOLD
}

/// Serde default for `MemoryGuardConfig::cgroup_headroom`.
const fn default_cgroup_headroom() -> f64 {
    DEFAULT_CGROUP_HEADROOM
}
77
78impl Default for MemoryGuardConfig {
79 fn default() -> Self {
80 Self {
81 limit_bytes: 0, pressure_threshold: DEFAULT_PRESSURE_THRESHOLD,
83 cgroup_headroom: DEFAULT_CGROUP_HEADROOM,
84 }
85 }
86}
87
88impl MemoryGuardConfig {
89 #[must_use]
95 pub fn from_cascade() -> Self {
96 #[cfg(feature = "config")]
97 {
98 if let Some(cfg) = crate::config::try_get()
99 && let Ok(memory) = cfg.unmarshal_key_registered::<Self>("memory")
100 {
101 return memory;
102 }
103 }
104 Self::default()
105 }
106
107 #[must_use]
127 #[cfg(feature = "config")]
128 pub fn from_env(prefix: &str) -> Self {
129 use crate::config::flat_env::flat_env_parsed;
130
131 let mut config = Self::default();
132
133 if let Some(v) = flat_env_parsed::<u64>(prefix, "MEMORY_LIMIT_BYTES") {
134 config.limit_bytes = v;
135 }
136 if let Some(v) = flat_env_parsed::<f64>(prefix, "MEMORY_PRESSURE_THRESHOLD") {
137 config.pressure_threshold = v;
138 }
139 if let Some(v) = flat_env_parsed::<f64>(prefix, "MEMORY_CGROUP_HEADROOM") {
140 config.cgroup_headroom = v;
141 }
142
143 config
144 }
145
146 #[must_use]
150 pub fn from_env_raw(prefix: &str) -> Self {
151 let mut config = Self::default();
152
153 if let Some(v) = env_parsed::<u64>(prefix, "MEMORY_LIMIT_BYTES") {
154 config.limit_bytes = v;
155 }
156 if let Some(v) = env_parsed::<f64>(prefix, "MEMORY_PRESSURE_THRESHOLD") {
157 config.pressure_threshold = v;
158 }
159 if let Some(v) = env_parsed::<f64>(prefix, "MEMORY_CGROUP_HEADROOM") {
160 config.cgroup_headroom = v;
161 }
162
163 config
164 }
165}
166
/// Tracks approximate memory usage against a fixed byte budget.
///
/// All mutable state lives in atomics, so one guard can be shared between
/// threads (for example behind an `Arc`).
#[derive(Debug)]
pub struct MemoryGuard {
    /// Bytes currently accounted for.
    current_bytes: AtomicU64,
    /// Byte budget resolved when the guard was built.
    limit_bytes: u64,
    /// Usage fraction at or above which the guard counts as under pressure.
    pressure_threshold: f64,
    /// Cached pressure flag, updated by the reserve/add/release operations.
    under_pressure: AtomicBool,
}
205
206impl MemoryGuard {
207 #[must_use]
212 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
213 pub fn new(config: MemoryGuardConfig) -> Self {
214 let raw_limit = if config.limit_bytes > 0 {
215 config.limit_bytes
216 } else {
217 let detected = cgroup::detect_memory_limit();
218 (detected as f64 * config.cgroup_headroom) as u64
220 };
221
222 tracing::info!(
223 limit_bytes = raw_limit,
224 pressure_threshold = config.pressure_threshold,
225 "memory guard initialised"
226 );
227
228 Self {
229 current_bytes: AtomicU64::new(0),
230 limit_bytes: raw_limit,
231 pressure_threshold: config.pressure_threshold,
232 under_pressure: AtomicBool::new(false),
233 }
234 }
235
236 #[inline]
241 pub fn try_reserve(&self, bytes: u64) -> bool {
242 let current = self.current_bytes.fetch_add(bytes, Ordering::Relaxed) + bytes;
243 if current > self.limit_bytes {
244 self.current_bytes.fetch_sub(bytes, Ordering::Relaxed);
246 self.under_pressure.store(true, Ordering::Relaxed);
247 return false;
248 }
249 self.update_pressure(current);
250 true
251 }
252
253 #[inline]
256 pub fn add_bytes(&self, bytes: u64) {
257 let new_total = self.current_bytes.fetch_add(bytes, Ordering::Relaxed) + bytes;
258 self.update_pressure(new_total);
259 }
260
261 #[inline]
265 pub fn release(&self, bytes: u64) {
266 let prev = self
267 .current_bytes
268 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
269 Some(current.saturating_sub(bytes))
270 })
271 .unwrap_or_else(|v| v);
273 self.update_pressure(prev.saturating_sub(bytes));
274 }
275
276 #[inline]
278 pub fn under_pressure(&self) -> bool {
279 self.under_pressure.load(Ordering::Relaxed)
280 }
281
282 #[inline]
284 pub fn pressure(&self) -> MemoryPressure {
285 let ratio = self.pressure_ratio();
286 if ratio >= self.pressure_threshold {
287 MemoryPressure::High
288 } else if ratio >= 0.5 {
289 MemoryPressure::Medium
290 } else {
291 MemoryPressure::Low
292 }
293 }
294
295 #[inline]
297 pub fn pressure_ratio(&self) -> f64 {
298 self.current_bytes.load(Ordering::Relaxed) as f64 / self.limit_bytes as f64
299 }
300
301 #[inline]
303 pub fn current_bytes(&self) -> u64 {
304 self.current_bytes.load(Ordering::Relaxed)
305 }
306
307 #[inline]
309 pub fn limit_bytes(&self) -> u64 {
310 self.limit_bytes
311 }
312
313 #[inline]
315 fn update_pressure(&self, current: u64) {
316 let ratio = current as f64 / self.limit_bytes as f64;
317 self.under_pressure
318 .store(ratio >= self.pressure_threshold, Ordering::Relaxed);
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Guard with an explicit `limit_bytes` and default threshold/headroom.
    fn guard_with_limit(limit_bytes: u64) -> MemoryGuard {
        MemoryGuard::new(MemoryGuardConfig {
            limit_bytes,
            ..Default::default()
        })
    }

    #[test]
    fn test_memory_guard_default() {
        let guard = guard_with_limit(1_000_000);
        assert_eq!(guard.limit_bytes(), 1_000_000);
        assert_eq!(guard.current_bytes(), 0);
        assert!(!guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::Low);
    }

    #[test]
    fn test_try_reserve_within_limit() {
        let guard = guard_with_limit(1000);
        assert!(guard.try_reserve(500));
        assert_eq!(guard.current_bytes(), 500);
    }

    #[test]
    fn test_try_reserve_over_limit() {
        let guard = guard_with_limit(1000);
        assert!(guard.try_reserve(500));
        // 500 + 600 exceeds the limit: rejected, nothing recorded, flag raised.
        assert!(!guard.try_reserve(600));
        assert_eq!(guard.current_bytes(), 500);
        assert!(guard.under_pressure());
    }

    #[test]
    fn test_try_reserve_rejection_flags_pressure_when_idle() {
        // A rejected reservation raises the flag even though recorded usage,
        // and therefore `pressure()`, stays low.
        let guard = guard_with_limit(1000);
        assert!(!guard.try_reserve(2000));
        assert_eq!(guard.current_bytes(), 0);
        assert!(guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::Low);
    }

    #[test]
    fn test_release_reduces_pressure() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            pressure_threshold: 0.8,
            ..Default::default()
        });
        guard.add_bytes(900); // 90% >= 80% threshold
        assert!(guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::High);

        guard.release(500); // back down to 40%
        assert!(!guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::Low);
    }

    #[test]
    fn test_pressure_levels() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            pressure_threshold: 0.8,
            ..Default::default()
        });

        guard.add_bytes(400); // 40%
        assert_eq!(guard.pressure(), MemoryPressure::Low);

        guard.add_bytes(200); // 60%
        assert_eq!(guard.pressure(), MemoryPressure::Medium);

        guard.add_bytes(300); // 90%
        assert_eq!(guard.pressure(), MemoryPressure::High);
    }

    #[test]
    fn test_pressure_ratio() {
        let guard = guard_with_limit(1000);
        guard.add_bytes(250);
        let ratio = guard.pressure_ratio();
        assert!((ratio - 0.25).abs() < 0.001);
    }

    #[test]
    fn test_release_saturating() {
        let guard = guard_with_limit(1000);
        guard.add_bytes(100);
        guard.release(200);
        assert_eq!(
            guard.current_bytes(),
            0,
            "over-release must saturate to 0, not wrap"
        );
        assert!(!guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::Low);

        // The guard stays usable after an over-release.
        assert!(guard.try_reserve(500));
        assert_eq!(guard.current_bytes(), 500);
    }

    #[test]
    fn test_concurrent_reserve_release() {
        use std::sync::Arc;
        use std::thread;

        let guard = Arc::new(MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 100_000,
            pressure_threshold: 0.8,
            ..Default::default()
        }));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let g = Arc::clone(&guard);
                thread::spawn(move || {
                    for _ in 0..100 {
                        g.add_bytes(100);
                        g.release(100);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }

        // Every thread releases exactly what it added, and no release precedes
        // its own add, so the counter never saturates and must end at zero.
        assert_eq!(
            guard.current_bytes(),
            0,
            "leaked bytes: {}",
            guard.current_bytes()
        );
        // Peak usage is 1_000 of 100_000 bytes, far below the threshold.
        assert!(!guard.under_pressure());
    }

    #[test]
    fn test_try_reserve_rollback_is_atomic() {
        let guard = guard_with_limit(100);
        assert!(guard.try_reserve(90));
        assert!(!guard.try_reserve(20)); // 110 > 100
        assert_eq!(guard.current_bytes(), 90);
        assert!(guard.try_reserve(10)); // exactly at the limit is allowed
        assert_eq!(guard.current_bytes(), 100);
    }

    #[test]
    fn test_config_defaults() {
        let config = MemoryGuardConfig::default();
        assert_eq!(config.limit_bytes, 0);
        assert!((config.pressure_threshold - 0.80).abs() < 0.001);
        assert!((config.cgroup_headroom - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_from_env_raw_defaults_when_unset() {
        let config = MemoryGuardConfig::from_env_raw("TEST_MG_UNSET");
        assert_eq!(config.limit_bytes, 0);
        assert!((config.pressure_threshold - 0.80).abs() < 0.001);
        assert!((config.cgroup_headroom - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_env_parsed_helper() {
        assert!(env_parsed::<u64>("NONEXISTENT_PREFIX_XYZ", "FOO").is_none());
        assert!(env_parsed::<f64>("NONEXISTENT_PREFIX_XYZ", "BAR").is_none());
    }

    #[test]
    fn test_guard_with_explicit_config_overrides() {
        let config = MemoryGuardConfig {
            limit_bytes: 2_147_483_648,
            pressure_threshold: 0.75,
            cgroup_headroom: 0.90,
        };
        let guard = MemoryGuard::new(config);
        // An explicit limit wins; headroom is not applied to it.
        assert_eq!(guard.limit_bytes(), 2_147_483_648);
    }

    #[test]
    fn test_guard_with_custom_headroom() {
        // NOTE: depends on the host reporting a non-zero memory limit.
        let config = MemoryGuardConfig {
            limit_bytes: 0,
            pressure_threshold: 0.80,
            cgroup_headroom: 0.85,
        };
        let guard = MemoryGuard::new(config);
        assert!(guard.limit_bytes() > 0);
    }

    #[test]
    fn test_auto_detect_limit() {
        // NOTE: depends on the host reporting a non-zero memory limit.
        let guard = MemoryGuard::new(MemoryGuardConfig::default());
        assert!(
            guard.limit_bytes() > 0,
            "auto-detected limit should be positive"
        );
    }
}