use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use super::cgroup;
/// Install-once, process-wide provider of the live heap size in bytes.
///
/// Once set, `MemoryGuard::current_bytes`, `try_reserve` and `under_pressure`
/// consult this function instead of the guard's own byte counter.
static HEAP_SOURCE: OnceLock<fn() -> usize> = OnceLock::new();

/// Installs `source` as the process-wide heap-size provider.
///
/// The first successful call wins and later calls are ignored. Returns `true`
/// if this call installed the provider, `false` if one was already present.
#[must_use]
pub fn set_heap_source(source: fn() -> usize) -> bool {
    let outcome = HEAP_SOURCE.set(source);
    outcome.is_ok()
}

/// Reads the heap size from the installed provider, or `None` when no provider
/// has been installed.
#[inline]
fn heap_bytes() -> Option<u64> {
    let source = HEAP_SOURCE.get()?;
    Some(source() as u64)
}
/// Reads the environment variable named `{prefix}_{suffix}` and parses it as `T`.
///
/// Yields `None` when the variable is unset, is not valid Unicode, or fails to
/// parse as `T`.
fn env_parsed<T: std::str::FromStr>(prefix: &str, suffix: &str) -> Option<T> {
    let key = format!("{prefix}_{suffix}");
    let raw = std::env::var(key).ok()?;
    raw.parse::<T>().ok()
}
/// Coarse memory-pressure level reported by `MemoryGuard::pressure`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryPressure {
    /// Usage ratio below 0.5.
    Low,
    /// Usage ratio at or above 0.5 but below the configured pressure threshold.
    Medium,
    /// Usage ratio at or above the configured pressure threshold.
    High,
}
/// Configuration for `MemoryGuard`.
///
/// Deserializable (e.g. from the `memory` config section read by
/// `MemoryGuardConfig::from_cascade`); missing fields take their defaults.
/// Fractions are not checked on deserialization — call `validate`, or rely on
/// `MemoryGuard::new`, which replaces invalid fractions with defaults.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct MemoryGuardConfig {
    /// Hard limit in bytes. `0` (the serde default) means "auto-detect": the
    /// guard uses the detected cgroup limit scaled by `cgroup_headroom`.
    #[serde(default)]
    pub limit_bytes: u64,
    /// Fraction of the limit, in `(0.0, 1.0]`, at or above which the guard
    /// reports pressure. Defaults to 0.80.
    #[serde(default = "default_pressure_threshold")]
    pub pressure_threshold: f64,
    /// Fraction, in `(0.0, 1.0]`, of the detected cgroup limit used when
    /// `limit_bytes` is 0. Defaults to 0.85.
    #[serde(default = "default_cgroup_headroom")]
    pub cgroup_headroom: f64,
}
/// Serde default for `MemoryGuardConfig::pressure_threshold`.
fn default_pressure_threshold() -> f64 {
    DEFAULT_PRESSURE_THRESHOLD
}
/// Serde default for `MemoryGuardConfig::cgroup_headroom`.
fn default_cgroup_headroom() -> f64 {
    DEFAULT_CGROUP_HEADROOM
}
/// Validates that `v` is a finite fraction in `(0.0, 1.0]`.
///
/// # Errors
/// Returns a human-readable message naming `memory.{name}` and the rejected
/// value when `v` is NaN, infinite, `<= 0.0`, or `> 1.0`.
fn check_fraction(v: f64, name: &str) -> Result<(), String> {
    let in_range = v.is_finite() && v > 0.0 && v <= 1.0;
    if in_range {
        Ok(())
    } else {
        Err(format!(
            "memory.{name} must be a finite fraction in (0.0, 1.0], got {v}"
        ))
    }
}
/// Returns `v` if it is a valid fraction (see `check_fraction`), otherwise logs
/// an error and returns `default`.
///
/// Used where a bad configuration value should degrade to a safe default
/// rather than fail construction.
fn sane_fraction(v: f64, default: f64, name: &str) -> f64 {
    match check_fraction(v, name) {
        Ok(()) => v,
        Err(_) => {
            tracing::error!(
                value = v,
                "invalid memory.{name} (need finite fraction in (0,1]); using default {default}"
            );
            default
        }
    }
}
/// Default fraction of the detected cgroup memory limit that the guard uses
/// when `limit_bytes` is 0.
const DEFAULT_CGROUP_HEADROOM: f64 = 0.85;
/// Default usage fraction of the limit at or above which the guard reports
/// pressure.
const DEFAULT_PRESSURE_THRESHOLD: f64 = 0.80;
impl Default for MemoryGuardConfig {
    /// Auto-detected limit (`limit_bytes == 0`) combined with the default
    /// pressure threshold and cgroup headroom.
    fn default() -> Self {
        let pressure_threshold = DEFAULT_PRESSURE_THRESHOLD;
        let cgroup_headroom = DEFAULT_CGROUP_HEADROOM;
        Self {
            limit_bytes: 0,
            pressure_threshold,
            cgroup_headroom,
        }
    }
}
impl MemoryGuardConfig {
    /// Loads the `memory` section from the global config (when the `config`
    /// feature is enabled), falling back to `MemoryGuardConfig::default()`.
    ///
    /// The fallback is used both when `crate::config::try_get()` yields nothing
    /// and when the `memory` section fails to unmarshal; the unmarshal error is
    /// discarded. Values are not validated here — see `validate`.
    #[must_use]
    pub fn from_cascade() -> Self {
        #[cfg(feature = "config")]
        {
            if let Some(cfg) = crate::config::try_get()
                && let Ok(memory) = cfg.unmarshal_key_registered::<Self>("memory")
            {
                return memory;
            }
        }
        Self::default()
    }

    /// Builds a config from environment variables using the config crate's
    /// `flat_env_parsed` helper with the keys `MEMORY_LIMIT_BYTES`,
    /// `MEMORY_PRESSURE_THRESHOLD` and `MEMORY_CGROUP_HEADROOM` under `prefix`.
    ///
    /// Each field keeps its default whenever the helper returns `None`.
    /// Values are not validated here.
    #[must_use]
    #[cfg(feature = "config")]
    pub fn from_env(prefix: &str) -> Self {
        use crate::config::flat_env::flat_env_parsed;
        let base = Self::default();
        Self {
            limit_bytes: flat_env_parsed::<u64>(prefix, "MEMORY_LIMIT_BYTES")
                .unwrap_or(base.limit_bytes),
            pressure_threshold: flat_env_parsed::<f64>(prefix, "MEMORY_PRESSURE_THRESHOLD")
                .unwrap_or(base.pressure_threshold),
            cgroup_headroom: flat_env_parsed::<f64>(prefix, "MEMORY_CGROUP_HEADROOM")
                .unwrap_or(base.cgroup_headroom),
        }
    }

    /// Builds a config from the plain environment variables
    /// `{prefix}_MEMORY_LIMIT_BYTES`, `{prefix}_MEMORY_PRESSURE_THRESHOLD` and
    /// `{prefix}_MEMORY_CGROUP_HEADROOM`.
    ///
    /// A field keeps its default when its variable is unset or fails to parse.
    /// Values are not validated here.
    #[must_use]
    pub fn from_env_raw(prefix: &str) -> Self {
        let base = Self::default();
        Self {
            limit_bytes: env_parsed(prefix, "MEMORY_LIMIT_BYTES").unwrap_or(base.limit_bytes),
            pressure_threshold: env_parsed(prefix, "MEMORY_PRESSURE_THRESHOLD")
                .unwrap_or(base.pressure_threshold),
            cgroup_headroom: env_parsed(prefix, "MEMORY_CGROUP_HEADROOM")
                .unwrap_or(base.cgroup_headroom),
        }
    }

    /// Checks that both fractions are finite and within `(0.0, 1.0]`.
    ///
    /// # Errors
    /// Returns the message for the first offending field;
    /// `pressure_threshold` is checked before `cgroup_headroom`.
    pub fn validate(&self) -> Result<(), String> {
        let fractions = [
            ("pressure_threshold", self.pressure_threshold),
            ("cgroup_headroom", self.cgroup_headroom),
        ];
        fractions
            .iter()
            .try_for_each(|&(name, value)| check_fraction(value, name))
    }
}
/// Byte-budget tracker that admits or rejects reservations against a fixed
/// limit and reports memory pressure.
///
/// When a heap source is installed via `set_heap_source`, reads and admission
/// use the live heap size instead of `current_bytes`.
pub struct MemoryGuard {
    /// Bytes accounted through `try_reserve` / `add_bytes` / `release`.
    current_bytes: AtomicU64,
    /// Effective limit in bytes; set once in `new` and floored at 1.
    limit_bytes: u64,
    /// Fraction of `limit_bytes` at or above which pressure is reported.
    pressure_threshold: f64,
    /// Cached pressure flag, consulted only when no heap source is installed.
    under_pressure: AtomicBool,
}
impl MemoryGuard {
    /// Creates a guard from `config`.
    ///
    /// Invalid fractions (non-finite or outside `(0.0, 1.0]`) are replaced by
    /// their defaults and logged. A `limit_bytes` of `0` selects
    /// auto-detection: the limit becomes `cgroup::detect_memory_limit()`
    /// scaled by `cgroup_headroom`. The resulting limit is floored at 1 byte
    /// so ratio computations never divide by zero.
    #[must_use]
    // Float-to-integer `as` casts saturate (NaN maps to 0), so the scaled
    // cgroup limit cannot wrap; the allow only silences clippy's cast lints.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn new(config: MemoryGuardConfig) -> Self {
        let pressure_threshold = sane_fraction(
            config.pressure_threshold,
            DEFAULT_PRESSURE_THRESHOLD,
            "pressure_threshold",
        );
        let cgroup_headroom = sane_fraction(
            config.cgroup_headroom,
            DEFAULT_CGROUP_HEADROOM,
            "cgroup_headroom",
        );
        let raw_limit = if config.limit_bytes > 0 {
            config.limit_bytes
        } else {
            let detected = cgroup::detect_memory_limit();
            (detected as f64 * cgroup_headroom) as u64
        };
        // A zero limit would make every ratio computation divide by zero.
        let limit_bytes = raw_limit.max(1);
        tracing::info!(limit_bytes, pressure_threshold, "memory guard initialised");
        Self {
            current_bytes: AtomicU64::new(0),
            limit_bytes,
            pressure_threshold,
            under_pressure: AtomicBool::new(false),
        }
    }

    /// Tries to reserve `bytes` against the limit and returns whether the
    /// request was admitted.
    ///
    /// With a heap source installed, admission is `heap + bytes <= limit`
    /// (saturating), and the internal counter is left untouched.
    ///
    /// Otherwise the counter advances in a single compare-and-swap that only
    /// commits when the new total fits. As a result:
    /// - a rejected request is never visible to concurrent callers;
    /// - an oversized `bytes` cannot wrap the counter.
    ///
    /// A rejection sets the under-pressure flag.
    #[inline]
    pub fn try_reserve(&self, bytes: u64) -> bool {
        if let Some(heap) = heap_bytes() {
            // Saturate so a huge request is rejected instead of wrapping.
            return heap.saturating_add(bytes) <= self.limit_bytes;
        }
        let limit = self.limit_bytes;
        let reserved = self
            .current_bytes
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                // `None` aborts the update: either the sum overflows u64 or it
                // exceeds the limit. In both cases the counter is unchanged.
                current.checked_add(bytes).filter(|&next| next <= limit)
            });
        match reserved {
            Ok(previous) => {
                // Cannot overflow: the closure only commits checked sums.
                self.update_pressure(previous + bytes);
                true
            }
            Err(_) => {
                self.under_pressure.store(true, Ordering::Relaxed);
                false
            }
        }
    }

    /// Unconditionally adds `bytes` to the counter (no limit check) and
    /// refreshes the pressure flag.
    #[inline]
    pub fn add_bytes(&self, bytes: u64) {
        let new_total = self.current_bytes.fetch_add(bytes, Ordering::Relaxed) + bytes;
        self.update_pressure(new_total);
    }

    /// Subtracts `bytes` from the counter, saturating at 0 on over-release,
    /// and refreshes the pressure flag.
    #[inline]
    pub fn release(&self, bytes: u64) {
        let prev = self
            .current_bytes
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_sub(bytes))
            })
            // The closure always returns `Some`, so this never yields `Err`.
            .unwrap_or_else(|v| v);
        self.update_pressure(prev.saturating_sub(bytes));
    }

    /// Whether the guard is under pressure.
    ///
    /// With a heap source installed this is recomputed from the live ratio;
    /// otherwise it is the flag last written by `try_reserve`, `add_bytes` or
    /// `release`.
    #[inline]
    pub fn under_pressure(&self) -> bool {
        if heap_bytes().is_some() {
            return self.pressure_ratio() >= self.pressure_threshold;
        }
        self.under_pressure.load(Ordering::Relaxed)
    }

    /// Coarse pressure level.
    ///
    /// Returns `High` at or above the threshold, `Medium` at or above a ratio
    /// of 0.5, and `Low` otherwise.
    #[inline]
    pub fn pressure(&self) -> MemoryPressure {
        let ratio = self.pressure_ratio();
        if ratio >= self.pressure_threshold {
            MemoryPressure::High
        } else if ratio >= 0.5 {
            MemoryPressure::Medium
        } else {
            MemoryPressure::Low
        }
    }

    /// Current usage divided by the limit (the limit is always >= 1).
    #[inline]
    pub fn pressure_ratio(&self) -> f64 {
        self.current_bytes() as f64 / self.limit_bytes as f64
    }

    /// Current usage in bytes.
    ///
    /// This is the heap source's value when one is installed, otherwise the
    /// internal counter.
    #[inline]
    pub fn current_bytes(&self) -> u64 {
        heap_bytes().unwrap_or_else(|| self.current_bytes.load(Ordering::Relaxed))
    }

    /// The effective limit in bytes.
    #[inline]
    pub fn limit_bytes(&self) -> u64 {
        self.limit_bytes
    }

    /// Recomputes the cached pressure flag from a counter value `current`.
    #[inline]
    fn update_pressure(&self, current: u64) {
        let ratio = current as f64 / self.limit_bytes as f64;
        self.under_pressure
            .store(ratio >= self.pressure_threshold, Ordering::Relaxed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns whether this test is running in a process of its own.
    ///
    /// `set_heap_source` installs a process-global provider that can never be
    /// removed. Once it is installed, `current_bytes`, `try_reserve` and
    /// `under_pressure` on every guard in the process read the provider
    /// instead of the guard's counter. Any test that installs it must
    /// therefore not share a process with the counter-based tests.
    ///
    /// cargo-nextest runs each test in its own process and sets `NEXTEST` in
    /// the test environment. Plain `cargo test` runs all tests as threads of
    /// one process and does not set it.
    fn running_in_own_process() -> bool {
        std::env::var_os("NEXTEST").is_some()
    }

    #[test]
    fn test_memory_guard_default() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1_000_000,
            ..Default::default()
        });
        assert_eq!(guard.limit_bytes(), 1_000_000);
        assert_eq!(guard.current_bytes(), 0);
        assert!(!guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::Low);
    }

    #[test]
    fn test_try_reserve_within_limit() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            ..Default::default()
        });
        assert!(guard.try_reserve(500));
        assert_eq!(guard.current_bytes(), 500);
    }

    #[test]
    fn test_try_reserve_over_limit() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            ..Default::default()
        });
        assert!(guard.try_reserve(500));
        assert!(!guard.try_reserve(600));
        assert_eq!(guard.current_bytes(), 500);
        // A rejection flags pressure even though 500/1000 is below threshold.
        assert!(guard.under_pressure());
    }

    #[test]
    fn test_release_reduces_pressure() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            pressure_threshold: 0.8,
            ..Default::default()
        });
        guard.add_bytes(900);
        assert!(guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::High);
        guard.release(500);
        assert!(!guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::Low);
    }

    #[test]
    fn test_pressure_levels() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            pressure_threshold: 0.8,
            ..Default::default()
        });
        guard.add_bytes(400);
        assert_eq!(guard.pressure(), MemoryPressure::Low);
        guard.add_bytes(200);
        assert_eq!(guard.pressure(), MemoryPressure::Medium);
        guard.add_bytes(300);
        assert_eq!(guard.pressure(), MemoryPressure::High);
    }

    #[test]
    fn test_pressure_ratio() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            ..Default::default()
        });
        guard.add_bytes(250);
        let ratio = guard.pressure_ratio();
        assert!((ratio - 0.25).abs() < 0.001);
    }

    #[test]
    fn test_release_saturating() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            ..Default::default()
        });
        guard.add_bytes(100);
        guard.release(200);
        assert_eq!(
            guard.current_bytes(),
            0,
            "over-release must saturate to 0, not wrap"
        );
        assert!(!guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::Low);
        assert!(guard.try_reserve(500));
        assert_eq!(guard.current_bytes(), 500);
    }

    #[test]
    fn test_concurrent_reserve_release() {
        use std::sync::Arc;
        use std::thread;
        let guard = Arc::new(MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 100_000,
            pressure_threshold: 0.8,
            ..Default::default()
        }));
        let mut handles = vec![];
        for _ in 0..10 {
            let g = Arc::clone(&guard);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    g.add_bytes(100);
                    g.release(100);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        // Every release follows its own add, so the counter never needs to
        // saturate and must return exactly to zero.
        assert_eq!(
            guard.current_bytes(),
            0,
            "leaked bytes: {}",
            guard.current_bytes()
        );
    }

    #[test]
    fn test_try_reserve_rollback_is_atomic() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 100,
            ..Default::default()
        });
        assert!(guard.try_reserve(90));
        assert!(!guard.try_reserve(20));
        assert_eq!(guard.current_bytes(), 90);
        assert!(guard.try_reserve(10));
        assert_eq!(guard.current_bytes(), 100);
    }

    static TEST_HEAP: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

    fn test_heap_source() -> usize {
        TEST_HEAP.load(Ordering::Relaxed)
    }

    #[test]
    fn heap_source_overrides_read_path_and_admission() {
        if !running_in_own_process() {
            // Installing the global heap source here would make every other
            // test in this shared process read TEST_HEAP instead of its own
            // counter, breaking them depending on scheduling order.
            eprintln!(
                "skipping heap_source_overrides_read_path_and_admission: \
                 requires a process-per-test runner (e.g. cargo nextest)"
            );
            return;
        }
        assert!(set_heap_source(test_heap_source), "first set wins");
        assert!(
            !set_heap_source(test_heap_source),
            "second set is a no-op (first-wins)"
        );
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1_000,
            pressure_threshold: 0.8,
            ..Default::default()
        });
        TEST_HEAP.store(250, Ordering::Relaxed);
        assert_eq!(guard.current_bytes(), 250);
        assert!((guard.pressure_ratio() - 0.25).abs() < 0.001);
        assert!(!guard.under_pressure());
        TEST_HEAP.store(850, Ordering::Relaxed);
        assert!(
            guard.under_pressure(),
            "85% live heap is over the 80% threshold"
        );
        assert_eq!(guard.pressure(), MemoryPressure::High);
        TEST_HEAP.store(900, Ordering::Relaxed);
        assert!(guard.try_reserve(100), "900 + 100 == limit, admitted");
        assert!(!guard.try_reserve(200), "900 + 200 > limit, rejected");
        assert_eq!(
            guard.current_bytes(),
            900,
            "counter untouched by try_reserve"
        );
    }

    #[test]
    fn test_config_defaults() {
        let config = MemoryGuardConfig::default();
        assert_eq!(config.limit_bytes, 0);
        assert!((config.pressure_threshold - 0.80).abs() < 0.001);
        assert!((config.cgroup_headroom - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_from_env_raw_defaults_when_unset() {
        let config = MemoryGuardConfig::from_env_raw("TEST_MG_UNSET");
        assert_eq!(config.limit_bytes, 0);
        assert!((config.pressure_threshold - 0.80).abs() < 0.001);
        assert!((config.cgroup_headroom - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_env_parsed_helper() {
        assert!(env_parsed::<u64>("NONEXISTENT_PREFIX_XYZ", "FOO").is_none());
        assert!(env_parsed::<f64>("NONEXISTENT_PREFIX_XYZ", "BAR").is_none());
    }

    #[test]
    fn test_guard_with_explicit_config_overrides() {
        let config = MemoryGuardConfig {
            limit_bytes: 2_147_483_648,
            pressure_threshold: 0.75,
            cgroup_headroom: 0.90,
        };
        let guard = MemoryGuard::new(config);
        assert_eq!(guard.limit_bytes(), 2_147_483_648);
    }

    #[test]
    fn test_guard_with_custom_headroom() {
        let config = MemoryGuardConfig {
            limit_bytes: 0,
            pressure_threshold: 0.80,
            cgroup_headroom: 0.85,
        };
        let guard = MemoryGuard::new(config);
        assert!(guard.limit_bytes() > 0);
    }

    #[test]
    fn test_validate_accepts_defaults_and_rejects_bad_fractions() {
        assert!(MemoryGuardConfig::default().validate().is_ok());
        for bad in [0.0, -0.1, 1.5, f64::NAN, f64::INFINITY] {
            let cfg = MemoryGuardConfig {
                pressure_threshold: bad,
                ..Default::default()
            };
            assert!(
                cfg.validate().is_err(),
                "pressure_threshold={bad} must be rejected"
            );
            let cfg = MemoryGuardConfig {
                cgroup_headroom: bad,
                ..Default::default()
            };
            assert!(
                cfg.validate().is_err(),
                "cgroup_headroom={bad} must be rejected"
            );
        }
    }

    #[test]
    fn test_new_clamps_invalid_config_no_divide_by_zero() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 0,
            pressure_threshold: 0.0,
            cgroup_headroom: 0.0,
        });
        assert!(guard.limit_bytes() >= 1, "limit floored at >=1");
        guard.add_bytes(10);
        assert!(
            guard.pressure_ratio().is_finite(),
            "pressure ratio must be finite, not div-by-zero"
        );
    }

    #[test]
    fn test_new_with_nan_threshold_is_finite() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            pressure_threshold: f64::NAN,
            cgroup_headroom: f64::NAN,
        });
        assert_eq!(guard.limit_bytes(), 1000);
        guard.add_bytes(900);
        assert!(guard.under_pressure());
    }

    #[test]
    fn test_auto_detect_limit() {
        let guard = MemoryGuard::new(MemoryGuardConfig::default());
        assert!(
            guard.limit_bytes() > 0,
            "auto-detected limit should be positive"
        );
    }
}