use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use super::cgroup;
/// Looks up the environment variable `{prefix}_{suffix}` and parses it as `T`.
///
/// Returns `None` when the variable cannot be read (unset or not valid Unicode)
/// or when its value does not parse as `T`.
fn env_parsed<T: std::str::FromStr>(prefix: &str, suffix: &str) -> Option<T> {
    let key = format!("{prefix}_{suffix}");
    match std::env::var(key) {
        Ok(raw) => raw.parse().ok(),
        Err(_) => None,
    }
}
/// Coarse memory-pressure level, as reported by `MemoryGuard::pressure`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryPressure {
    /// Usage ratio is below both 0.5 and the configured pressure threshold.
    Low,
    /// Usage ratio is at least 0.5 but still below the configured pressure threshold.
    Medium,
    /// Usage ratio is at or above the configured pressure threshold.
    High,
}
/// Settings consumed by `MemoryGuard::new`.
///
/// Every field has a serde default, so a partially specified (or empty)
/// configuration document deserializes successfully.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct MemoryGuardConfig {
    /// Explicit limit in bytes. `0` (the default) makes `MemoryGuard::new`
    /// derive the limit from the detected cgroup memory limit instead.
    #[serde(default)]
    pub limit_bytes: u64,
    /// Usage ratio (`current / limit`) at or above which the guard reports
    /// pressure. Defaults to `DEFAULT_PRESSURE_THRESHOLD`.
    #[serde(default = "default_pressure_threshold")]
    pub pressure_threshold: f64,
    /// Multiplier applied to the detected cgroup limit when `limit_bytes` is `0`.
    /// Defaults to `DEFAULT_CGROUP_HEADROOM`.
    #[serde(default = "default_cgroup_headroom")]
    pub cgroup_headroom: f64,
}
/// Serde default hook for `MemoryGuardConfig::pressure_threshold`.
fn default_pressure_threshold() -> f64 {
    DEFAULT_PRESSURE_THRESHOLD
}
/// Serde default hook for `MemoryGuardConfig::cgroup_headroom`.
fn default_cgroup_headroom() -> f64 {
    DEFAULT_CGROUP_HEADROOM
}
/// Default multiplier applied to the detected cgroup memory limit when no
/// explicit limit is configured.
const DEFAULT_CGROUP_HEADROOM: f64 = 0.85;
/// Default usage ratio at or above which the guard reports memory pressure.
const DEFAULT_PRESSURE_THRESHOLD: f64 = 0.80;
impl Default for MemoryGuardConfig {
    /// Returns the configuration used when nothing is specified: no explicit
    /// byte limit, plus the module's default threshold and headroom.
    fn default() -> Self {
        let pressure_threshold = DEFAULT_PRESSURE_THRESHOLD;
        let cgroup_headroom = DEFAULT_CGROUP_HEADROOM;
        Self {
            // Zero means "no explicit limit"; `MemoryGuard::new` then falls
            // back to the detected cgroup limit.
            limit_bytes: 0,
            pressure_threshold,
            cgroup_headroom,
        }
    }
}
impl MemoryGuardConfig {
    /// Loads the configuration from the crate-wide config, if available.
    ///
    /// With the `config` feature enabled, the `"memory"` key is unmarshalled
    /// from the value returned by `crate::config::try_get()`. If no config is
    /// available, the key fails to unmarshal, or the feature is disabled, the
    /// [`Default`] configuration is returned instead.
    #[must_use]
    pub fn from_cascade() -> Self {
        #[cfg(feature = "config")]
        {
            let loaded = crate::config::try_get()
                .and_then(|cfg| cfg.unmarshal_key_registered::<Self>("memory").ok());
            if let Some(memory) = loaded {
                return memory;
            }
        }
        Self::default()
    }

    /// Builds a configuration from environment values resolved through
    /// `flat_env_parsed`, starting from the defaults.
    ///
    /// The suffixes consulted are `MEMORY_LIMIT_BYTES`,
    /// `MEMORY_PRESSURE_THRESHOLD` and `MEMORY_CGROUP_HEADROOM`; any value the
    /// helper does not yield keeps its default.
    #[must_use]
    #[cfg(feature = "config")]
    pub fn from_env(prefix: &str) -> Self {
        use crate::config::flat_env::flat_env_parsed;

        let base = Self::default();
        let limit_bytes =
            flat_env_parsed::<u64>(prefix, "MEMORY_LIMIT_BYTES").unwrap_or(base.limit_bytes);
        let pressure_threshold = flat_env_parsed::<f64>(prefix, "MEMORY_PRESSURE_THRESHOLD")
            .unwrap_or(base.pressure_threshold);
        let cgroup_headroom = flat_env_parsed::<f64>(prefix, "MEMORY_CGROUP_HEADROOM")
            .unwrap_or(base.cgroup_headroom);

        Self {
            limit_bytes,
            pressure_threshold,
            cgroup_headroom,
        }
    }

    /// Builds a configuration straight from process environment variables,
    /// starting from the defaults.
    ///
    /// Reads `{prefix}_MEMORY_LIMIT_BYTES`, `{prefix}_MEMORY_PRESSURE_THRESHOLD`
    /// and `{prefix}_MEMORY_CGROUP_HEADROOM`; a variable that is unset or does
    /// not parse leaves the corresponding default in place.
    #[must_use]
    pub fn from_env_raw(prefix: &str) -> Self {
        let base = Self::default();
        let limit_bytes =
            env_parsed::<u64>(prefix, "MEMORY_LIMIT_BYTES").unwrap_or(base.limit_bytes);
        let pressure_threshold = env_parsed::<f64>(prefix, "MEMORY_PRESSURE_THRESHOLD")
            .unwrap_or(base.pressure_threshold);
        let cgroup_headroom =
            env_parsed::<f64>(prefix, "MEMORY_CGROUP_HEADROOM").unwrap_or(base.cgroup_headroom);

        Self {
            limit_bytes,
            pressure_threshold,
            cgroup_headroom,
        }
    }
}
/// Byte-accounting guard that tracks usage against a fixed limit and exposes
/// a pressure signal derived from the usage ratio.
pub struct MemoryGuard {
    /// Bytes currently accounted for.
    current_bytes: AtomicU64,
    /// Upper bound enforced by `try_reserve`; fixed at construction.
    limit_bytes: u64,
    /// Usage ratio at or above which the guard reports pressure.
    pressure_threshold: f64,
    /// Cached pressure flag: refreshed whenever the counter changes and set
    /// when a reservation is rejected.
    under_pressure: AtomicBool,
}
impl MemoryGuard {
    /// Usage ratio at or above which [`MemoryPressure::Medium`] is reported
    /// (unless the configured threshold already classifies it as `High`).
    const MEDIUM_PRESSURE_RATIO: f64 = 0.5;

    /// Builds a guard from `config`.
    ///
    /// A non-zero `config.limit_bytes` is used verbatim. Otherwise the limit is
    /// the value reported by `cgroup::detect_memory_limit()` multiplied by
    /// `config.cgroup_headroom`. The resolved limit and threshold are logged at
    /// `info` level.
    #[must_use]
    // The float -> integer cast is deliberate: it saturates at the `u64` bounds
    // and maps NaN to 0, which is acceptable for a byte budget.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn new(config: MemoryGuardConfig) -> Self {
        let raw_limit = if config.limit_bytes > 0 {
            config.limit_bytes
        } else {
            let detected = cgroup::detect_memory_limit();
            (detected as f64 * config.cgroup_headroom) as u64
        };
        tracing::info!(
            limit_bytes = raw_limit,
            pressure_threshold = config.pressure_threshold,
            "memory guard initialised"
        );
        Self {
            current_bytes: AtomicU64::new(0),
            limit_bytes: raw_limit,
            pressure_threshold: config.pressure_threshold,
            under_pressure: AtomicBool::new(false),
        }
    }

    /// Attempts to account for `bytes` without exceeding the limit.
    ///
    /// Returns `true` and records the bytes when the new total stays within
    /// the limit. Returns `false`, leaves the counter untouched and raises the
    /// pressure flag when the new total would exceed the limit or overflow
    /// `u64`.
    ///
    /// The check and the update happen in a single compare-and-swap loop, so
    /// the counter never transiently exceeds the limit and a rejected request
    /// cannot disturb concurrent callers.
    #[inline]
    pub fn try_reserve(&self, bytes: u64) -> bool {
        let limit = self.limit_bytes;
        let outcome = self.current_bytes.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |current| current.checked_add(bytes).filter(|&total| total <= limit),
        );
        match outcome {
            Ok(previous) => {
                // Cannot overflow: the update only succeeded because
                // `checked_add` succeeded for this exact `previous` value.
                self.update_pressure(previous + bytes);
                true
            }
            Err(_) => {
                self.under_pressure.store(true, Ordering::Relaxed);
                false
            }
        }
    }

    /// Unconditionally accounts for `bytes`, even past the limit.
    ///
    /// The counter saturates at `u64::MAX` instead of wrapping, mirroring the
    /// saturating behaviour of [`release`](Self::release).
    #[inline]
    pub fn add_bytes(&self, bytes: u64) {
        let previous = self
            .current_bytes
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(bytes))
            })
            // The closure always returns `Some`, so `Err` is unreachable; fold
            // it into the same value instead of panicking.
            .unwrap_or_else(|unchanged| unchanged);
        self.update_pressure(previous.saturating_add(bytes));
    }

    /// Gives back `bytes`, saturating at zero if more is released than is
    /// currently accounted for.
    #[inline]
    pub fn release(&self, bytes: u64) {
        let previous = self
            .current_bytes
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_sub(bytes))
            })
            // See `add_bytes`: the `Err` arm cannot occur.
            .unwrap_or_else(|unchanged| unchanged);
        self.update_pressure(previous.saturating_sub(bytes));
    }

    /// Returns the cached pressure flag.
    ///
    /// The flag is refreshed on every counter change and is also set when
    /// [`try_reserve`](Self::try_reserve) rejects a request.
    #[inline]
    pub fn under_pressure(&self) -> bool {
        self.under_pressure.load(Ordering::Relaxed)
    }

    /// Classifies the current usage ratio.
    ///
    /// `High` at or above the configured threshold, `Medium` at or above 0.5,
    /// `Low` otherwise.
    #[inline]
    pub fn pressure(&self) -> MemoryPressure {
        let ratio = self.pressure_ratio();
        if ratio >= self.pressure_threshold {
            MemoryPressure::High
        } else if ratio >= Self::MEDIUM_PRESSURE_RATIO {
            MemoryPressure::Medium
        } else {
            MemoryPressure::Low
        }
    }

    /// Returns `current_bytes / limit_bytes` as a float.
    ///
    /// Zero usage always yields `0.0`, including when the limit is zero.
    #[inline]
    pub fn pressure_ratio(&self) -> f64 {
        self.ratio_of(self.current_bytes.load(Ordering::Relaxed))
    }

    /// Returns the number of bytes currently accounted for.
    #[inline]
    pub fn current_bytes(&self) -> u64 {
        self.current_bytes.load(Ordering::Relaxed)
    }

    /// Returns the limit resolved at construction time.
    #[inline]
    pub fn limit_bytes(&self) -> u64 {
        self.limit_bytes
    }

    /// Computes the usage ratio for `current` bytes against the limit.
    #[inline]
    fn ratio_of(&self, current: u64) -> f64 {
        if current == 0 {
            // Mathematically 0 for any non-zero limit; for a zero limit this
            // avoids 0/0, which would otherwise produce NaN.
            return 0.0;
        }
        current as f64 / self.limit_bytes as f64
    }

    /// Refreshes the cached pressure flag for a counter value of `current`.
    #[inline]
    fn update_pressure(&self, current: u64) {
        let pressured = self.ratio_of(current) >= self.pressure_threshold;
        self.under_pressure.store(pressured, Ordering::Relaxed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh guard starts empty and reports no pressure.
    #[test]
    fn test_memory_guard_default() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1_000_000,
            ..Default::default()
        });
        assert_eq!(guard.limit_bytes(), 1_000_000);
        assert_eq!(guard.current_bytes(), 0);
        assert!(!guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::Low);
    }

    /// A reservation that fits is accepted and recorded.
    #[test]
    fn test_try_reserve_within_limit() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            ..Default::default()
        });
        assert!(guard.try_reserve(500));
        assert_eq!(guard.current_bytes(), 500);
    }

    /// A reservation that would exceed the limit is rejected, leaves the
    /// counter untouched and raises the pressure flag.
    #[test]
    fn test_try_reserve_over_limit() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            ..Default::default()
        });
        assert!(guard.try_reserve(500));
        assert!(!guard.try_reserve(600));
        assert_eq!(guard.current_bytes(), 500);
        assert!(guard.under_pressure());
    }

    /// Releasing bytes below the threshold clears the pressure signal.
    #[test]
    fn test_release_reduces_pressure() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            pressure_threshold: 0.8,
            ..Default::default()
        });
        guard.add_bytes(900);
        assert!(guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::High);

        guard.release(500);
        assert!(!guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::Low);
    }

    /// Usage walks through Low -> Medium -> High as it grows.
    #[test]
    fn test_pressure_levels() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            pressure_threshold: 0.8,
            ..Default::default()
        });
        guard.add_bytes(400);
        assert_eq!(guard.pressure(), MemoryPressure::Low);

        guard.add_bytes(200);
        assert_eq!(guard.pressure(), MemoryPressure::Medium);

        guard.add_bytes(300);
        assert_eq!(guard.pressure(), MemoryPressure::High);
    }

    /// The ratio is `current / limit`.
    #[test]
    fn test_pressure_ratio() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            ..Default::default()
        });
        guard.add_bytes(250);
        let ratio = guard.pressure_ratio();
        assert!((ratio - 0.25).abs() < 0.001);
    }

    /// Releasing more than is held clamps the counter at zero.
    #[test]
    fn test_release_saturating() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            ..Default::default()
        });
        guard.add_bytes(100);
        guard.release(200);
        assert_eq!(
            guard.current_bytes(),
            0,
            "over-release must saturate to 0, not wrap"
        );
        assert!(!guard.under_pressure());
        assert_eq!(guard.pressure(), MemoryPressure::Low);
        assert!(guard.try_reserve(500));
        assert_eq!(guard.current_bytes(), 500);
    }

    /// Balanced add/release pairs from many threads must net out exactly.
    #[test]
    fn test_concurrent_reserve_release() {
        use std::sync::Arc;
        use std::thread;

        let guard = Arc::new(MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 100_000,
            pressure_threshold: 0.8,
            ..Default::default()
        }));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let g = Arc::clone(&guard);
                thread::spawn(move || {
                    for _ in 0..100 {
                        g.add_bytes(100);
                        g.release(100);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }

        // Every release is preceded by the same thread's matching add, so the
        // counter never saturates and any non-zero remainder is a leak.
        assert_eq!(
            guard.current_bytes(),
            0,
            "leaked bytes: {}",
            guard.current_bytes()
        );
    }

    /// A rejected reservation leaves exactly the previous total behind, so a
    /// subsequent request that fits still succeeds.
    #[test]
    fn test_try_reserve_rollback_is_atomic() {
        let guard = MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 100,
            ..Default::default()
        });
        assert!(guard.try_reserve(90));
        assert!(!guard.try_reserve(20));
        assert_eq!(guard.current_bytes(), 90);
        assert!(guard.try_reserve(10));
        assert_eq!(guard.current_bytes(), 100);
    }

    #[test]
    fn test_config_defaults() {
        let config = MemoryGuardConfig::default();
        assert_eq!(config.limit_bytes, 0);
        assert!((config.pressure_threshold - 0.80).abs() < 0.001);
        assert!((config.cgroup_headroom - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_from_env_raw_defaults_when_unset() {
        let config = MemoryGuardConfig::from_env_raw("TEST_MG_UNSET");
        assert_eq!(config.limit_bytes, 0);
        assert!((config.pressure_threshold - 0.80).abs() < 0.001);
        assert!((config.cgroup_headroom - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_env_parsed_helper() {
        assert!(env_parsed::<u64>("NONEXISTENT_PREFIX_XYZ", "FOO").is_none());
        assert!(env_parsed::<f64>("NONEXISTENT_PREFIX_XYZ", "BAR").is_none());
    }

    /// An explicit limit wins over cgroup detection.
    #[test]
    fn test_guard_with_explicit_config_overrides() {
        let config = MemoryGuardConfig {
            limit_bytes: 2_147_483_648,
            pressure_threshold: 0.75,
            cgroup_headroom: 0.90,
        };
        let guard = MemoryGuard::new(config);
        assert_eq!(guard.limit_bytes(), 2_147_483_648);
    }

    /// With no explicit limit, a non-default headroom still yields a usable
    /// (positive) limit.
    #[test]
    fn test_guard_with_custom_headroom() {
        let config = MemoryGuardConfig {
            limit_bytes: 0,
            pressure_threshold: 0.80,
            cgroup_headroom: 0.50,
        };
        let guard = MemoryGuard::new(config);
        assert!(guard.limit_bytes() > 0);
    }

    #[test]
    fn test_auto_detect_limit() {
        let guard = MemoryGuard::new(MemoryGuardConfig::default());
        assert!(
            guard.limit_bytes() > 0,
            "auto-detected limit should be positive"
        );
    }
}