use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
/// Snapshot of the figures a [`MemoryGuard`] reports through `stats()`.
#[derive(Clone, Debug, Default)]
pub struct MemoryStats {
    /// Resident memory measured when the snapshot was taken, in bytes.
    pub resident_bytes: u64,
    /// Largest resident memory the guard has recorded during a sample, in bytes.
    pub peak_resident_bytes: u64,
    /// Value of the guard's operation counter, i.e. the number of counted `check` calls.
    pub checks_performed: u64,
    /// How many samples were above the soft limit.
    pub soft_limit_warnings: u64,
    /// Whether any sample was above the hard limit.
    pub hard_limit_exceeded: bool,
}
/// Limits and sampling settings used by a memory guard.
#[derive(Debug, Clone)]
pub struct MemoryGuardConfig {
    /// Hard limit in MB. A value of `0` disables the guard's checks.
    pub hard_limit_mb: usize,
    /// Soft limit in MB. A value of `0` disables soft-limit warnings.
    pub soft_limit_mb: usize,
    /// Number of `check` calls between memory samples.
    pub check_interval: usize,
    /// When `true`, the guard samples five times as often as `check_interval` asks for.
    pub aggressive_mode: bool,
    /// Growth-rate threshold, in MB per second.
    pub max_growth_rate_mb_per_sec: f64,
}

impl Default for MemoryGuardConfig {
    /// Returns a configuration with both limits set to `0` (guard disabled),
    /// a check interval of 500, aggressive mode off and a 100 MB/s growth threshold.
    fn default() -> Self {
        Self {
            hard_limit_mb: 0,
            soft_limit_mb: 0,
            check_interval: 500,
            aggressive_mode: false,
            max_growth_rate_mb_per_sec: 100.0,
        }
    }
}

impl MemoryGuardConfig {
    /// Builds a configuration with the given hard limit and a soft limit at 80% of it
    /// (rounded down); every other field takes its default value.
    pub fn with_limit_mb(hard_limit_mb: usize) -> Self {
        // floor(hard * 80 / 100), split into quotient and remainder parts so the
        // intermediate `hard * 80` is never formed; that product overflows for very
        // large limits such as `usize::MAX` used to mean "unlimited".
        let soft_limit_mb = (hard_limit_mb / 100) * 80 + (hard_limit_mb % 100) * 80 / 100;
        Self {
            hard_limit_mb,
            soft_limit_mb,
            ..Default::default()
        }
    }

    /// Switches the configuration to aggressive mode and sets the check interval to 100.
    pub fn aggressive(mut self) -> Self {
        self.aggressive_mode = true;
        self.check_interval = 100;
        self
    }
}
/// Error value describing a memory limit that was crossed.
#[derive(Clone, Debug)]
pub struct MemoryLimitExceeded {
    /// Memory usage observed when the limit was crossed, in MB.
    pub current_mb: usize,
    /// The limit that was crossed, in MB.
    pub limit_mb: usize,
    /// `true` if the crossed limit was the soft one, `false` for the hard one.
    pub is_soft_limit: bool,
    /// Human-readable description; this is what `Display` prints.
    pub message: String,
}

impl std::fmt::Display for MemoryLimitExceeded {
    /// Writes `message` verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for MemoryLimitExceeded {}
/// Watches process memory against configured limits, keeping its bookkeeping in atomics
/// so every method can be called through a shared reference.
#[derive(Debug)]
pub struct MemoryGuard {
    /// Limits and sampling settings supplied at construction.
    config: MemoryGuardConfig,
    /// Number of `check` calls counted so far.
    operation_counter: AtomicU64,
    /// Highest memory usage seen by a sample, in MB.
    peak_memory_mb: AtomicUsize,
    /// Number of samples that were above the soft limit.
    soft_warnings_count: AtomicU64,
    /// Set once a sample has been above the hard limit.
    hard_limit_exceeded: AtomicBool,
    /// Wall-clock time of the most recent sample, in nanoseconds since the Unix epoch.
    last_check_time_ns: AtomicU64,
    /// Memory usage recorded by the most recent sample, in MB.
    last_check_memory_mb: AtomicUsize,
}
impl MemoryGuard {
    /// Creates a guard for `config` with every counter and flag reset.
    pub fn new(config: MemoryGuardConfig) -> Self {
        Self {
            config,
            operation_counter: AtomicU64::new(0),
            peak_memory_mb: AtomicUsize::new(0),
            soft_warnings_count: AtomicU64::new(0),
            hard_limit_exceeded: AtomicBool::new(false),
            last_check_time_ns: AtomicU64::new(0),
            last_check_memory_mb: AtomicUsize::new(0),
        }
    }

    /// Creates a guard from the default configuration, whose hard limit of `0`
    /// leaves checking disabled.
    pub fn default_guard() -> Self {
        Self::new(MemoryGuardConfig::default())
    }

    /// Creates a guard with the given hard limit in MB; the remaining settings come
    /// from [`MemoryGuardConfig::with_limit_mb`].
    pub fn with_limit(limit_mb: usize) -> Self {
        Self::new(MemoryGuardConfig::with_limit_mb(limit_mb))
    }

    /// Creates a guard for `config` wrapped in an [`Arc`].
    pub fn shared(config: MemoryGuardConfig) -> Arc<Self> {
        Arc::new(Self::new(config))
    }

    /// Counts one operation and samples memory on every `check_interval`-th call
    /// (every `check_interval / 5`-th call in aggressive mode), starting with the first.
    ///
    /// Returns `Ok(())` without counting when the hard limit is `0`.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryLimitExceeded`] when a sample is above the hard limit.
    pub fn check(&self) -> Result<(), MemoryLimitExceeded> {
        if self.config.hard_limit_mb == 0 {
            return Ok(());
        }
        let count = self.operation_counter.fetch_add(1, Ordering::Relaxed);
        let configured = if self.config.aggressive_mode {
            self.config.check_interval / 5
        } else {
            self.config.check_interval
        };
        // `n.is_multiple_of(0)` is true only for `n == 0`, so an interval of zero
        // (a `check_interval` of 0, or below 5 in aggressive mode) would sample on the
        // first call and never again. Clamp to 1 so such settings sample on every call.
        let interval = configured.max(1) as u64;
        if !count.is_multiple_of(interval) {
            return Ok(());
        }
        self.check_now()
    }

    /// Samples memory immediately and compares it with the configured limits.
    ///
    /// Updates the recorded peak and the last-sample bookkeeping, raises the
    /// hard-limit flag or bumps the soft-warning counter as appropriate.
    /// Returns `Ok(())` without sampling when the hard limit is `0`.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryLimitExceeded`] when usage is above the hard limit.
    pub fn check_now(&self) -> Result<(), MemoryLimitExceeded> {
        if self.config.hard_limit_mb == 0 {
            return Ok(());
        }
        // Best effort: when usage cannot be read it counts as 0 MB, so no limit trips.
        let current_mb = get_memory_usage_mb().unwrap_or(0);
        self.peak_memory_mb.fetch_max(current_mb, Ordering::Relaxed);

        let now_ns = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(0);
        let last_time = self.last_check_time_ns.swap(now_ns, Ordering::Relaxed);
        let last_mem = self.last_check_memory_mb.swap(current_mb, Ordering::Relaxed);
        // Skip the first sample (no previous timestamp), a clock that did not advance,
        // and samples where memory did not grow.
        if last_time > 0 && now_ns > last_time && current_mb > last_mem {
            let elapsed_sec = (now_ns - last_time) as f64 / 1_000_000_000.0;
            let growth_rate = (current_mb - last_mem) as f64 / elapsed_sec;
            // NOTE(review): the growth rate is compared with the configured maximum but
            // nothing acts on the outcome; presumably a hook for a warning. Confirm.
            let _growth_limit_exceeded = growth_rate > self.config.max_growth_rate_mb_per_sec;
        }

        if current_mb > self.config.hard_limit_mb {
            self.hard_limit_exceeded.store(true, Ordering::Relaxed);
            return Err(MemoryLimitExceeded {
                current_mb,
                limit_mb: self.config.hard_limit_mb,
                is_soft_limit: false,
                message: format!(
                    "Memory limit exceeded: using {} MB, hard limit is {} MB. \
                     Reduce transaction volume or increase memory_limit_mb in config.",
                    current_mb, self.config.hard_limit_mb
                ),
            });
        }
        if self.config.soft_limit_mb > 0 && current_mb > self.config.soft_limit_mb {
            self.soft_warnings_count.fetch_add(1, Ordering::Relaxed);
        }
        Ok(())
    }

    /// Returns a snapshot of the counters together with a fresh memory reading
    /// (reported as 0 bytes when usage cannot be read).
    pub fn stats(&self) -> MemoryStats {
        let current = get_memory_usage_mb().unwrap_or(0);
        MemoryStats {
            resident_bytes: (current as u64) * 1024 * 1024,
            peak_resident_bytes: (self.peak_memory_mb.load(Ordering::Relaxed) as u64) * 1024 * 1024,
            checks_performed: self.operation_counter.load(Ordering::Relaxed),
            soft_limit_warnings: self.soft_warnings_count.load(Ordering::Relaxed),
            hard_limit_exceeded: self.hard_limit_exceeded.load(Ordering::Relaxed),
        }
    }

    /// Returns the current memory usage in MB, or 0 when it cannot be read.
    pub fn current_usage_mb(&self) -> usize {
        get_memory_usage_mb().unwrap_or(0)
    }

    /// Returns the highest usage recorded by a sample so far, in MB.
    pub fn peak_usage_mb(&self) -> usize {
        self.peak_memory_mb.load(Ordering::Relaxed)
    }

    /// Reports whether memory usage can be read at the moment of the call.
    pub fn is_available() -> bool {
        get_memory_usage_mb().is_some()
    }

    /// Clears the operation counter, the soft-warning counter and the hard-limit flag.
    ///
    /// The recorded peak and the last-sample bookkeeping are left untouched.
    pub fn reset_stats(&self) {
        self.operation_counter.store(0, Ordering::Relaxed);
        self.soft_warnings_count.store(0, Ordering::Relaxed);
        self.hard_limit_exceeded.store(false, Ordering::Relaxed);
    }
}
impl Default for MemoryGuard {
fn default() -> Self {
Self::default_guard()
}
}
/// Returns the resident set size of the current process in whole MB, or `None` when
/// neither `/proc/self/status` nor `/proc/self/statm` yields a usable value.
///
/// `VmRSS` from `/proc/self/status` is preferred because it is reported in kB and is
/// therefore correct whatever the kernel page size is. `/proc/self/statm` counts pages,
/// and the page size cannot be queried with the standard library alone, so it only
/// serves as a fallback.
#[cfg(target_os = "linux")]
pub fn get_memory_usage_mb() -> Option<usize> {
    use std::fs;

    /// Page size assumed for the `statm` fallback; wrong on 16 KiB / 64 KiB-page kernels.
    const ASSUMED_PAGE_SIZE_KB: usize = 4;

    /// Pulls the `VmRSS` figure (in kB) out of the text of `/proc/self/status`.
    fn vm_rss_kb(status: &str) -> Option<usize> {
        status
            .lines()
            .find_map(|line| line.strip_prefix("VmRSS:"))
            .and_then(|rest| rest.split_whitespace().next())
            .and_then(|kb| kb.parse().ok())
    }

    /// Pulls the resident page count (second field) out of the text of `/proc/self/statm`.
    fn statm_resident_pages(statm: &str) -> Option<usize> {
        statm.split_whitespace().nth(1)?.parse().ok()
    }

    let from_status = fs::read_to_string("/proc/self/status")
        .ok()
        .and_then(|text| vm_rss_kb(&text));
    if let Some(kb) = from_status {
        return Some(kb / 1024);
    }

    let pages = fs::read_to_string("/proc/self/statm")
        .ok()
        .and_then(|text| statm_resident_pages(&text))?;
    Some(pages.saturating_mul(ASSUMED_PAGE_SIZE_KB) / 1024)
}
/// Returns the resident set size of the current process in whole MB by asking
/// `ps -o rss= -p <pid>`, or `None` when the command cannot be run or its output
/// does not parse as a number.
#[cfg(target_os = "macos")]
pub fn get_memory_usage_mb() -> Option<usize> {
    let pid = std::process::id().to_string();
    let ps = std::process::Command::new("ps")
        .args(["-o", "rss=", "-p", pid.as_str()])
        .output()
        .ok()?;
    // The parsed figure is treated as kB, hence the division by 1024.
    let text = String::from_utf8_lossy(&ps.stdout);
    let rss_kb = text.trim().parse::<usize>().ok()?;
    Some(rss_kb / 1024)
}
/// Fallback for every platform other than Linux and macOS (Windows included):
/// memory usage is not measured there, so this always returns `None`.
#[cfg(not(any(target_os = "linux", target_os = "macos")))]
pub fn get_memory_usage_mb() -> Option<usize> {
    None
}
/// Estimates the memory needed for `num_entries` entries, in MB (rounded up).
///
/// Each entry is costed at 700 bytes plus 300 bytes per line, and the total is
/// scaled by 1.5 for overhead. All arithmetic saturates, so extreme inputs produce
/// a very large estimate instead of overflowing.
pub fn estimate_memory_mb(num_entries: usize, avg_lines_per_entry: usize) -> usize {
    const FIXED_BYTES_PER_ENTRY: usize = 500 + 200;
    const BYTES_PER_LINE: usize = 300;
    const BYTES_PER_MB: usize = 1024 * 1024;

    let bytes_per_entry = avg_lines_per_entry
        .saturating_mul(BYTES_PER_LINE)
        .saturating_add(FIXED_BYTES_PER_ENTRY);
    let total_bytes = num_entries.saturating_mul(bytes_per_entry);
    // 1.5x overhead in integer arithmetic: equals truncating `total * 1.5`, without
    // the precision loss of a round trip through `f64` for large totals.
    let with_overhead = total_bytes.saturating_add(total_bytes / 2);
    with_overhead.div_ceil(BYTES_PER_MB)
}
/// Checks whether the estimated memory for a planned workload fits within a limit.
///
/// The estimate comes from `estimate_memory_mb(planned_entries, avg_lines)`.
/// An `available_limit_mb` of `0` means "no limit" and always succeeds.
///
/// # Errors
///
/// Returns a message naming the estimate, the limit and a proportionally reduced
/// entry count when the estimate is above the limit.
pub fn check_sufficient_memory(
    planned_entries: usize,
    avg_lines: usize,
    available_limit_mb: usize,
) -> Result<(), String> {
    let estimated = estimate_memory_mb(planned_entries, avg_lines);
    if available_limit_mb == 0 || estimated <= available_limit_mb {
        return Ok(());
    }
    // Here `estimated > available_limit_mb >= 1`, so the division cannot be by zero.
    // The product is formed in `u128` because `planned_entries * available_limit_mb`
    // can exceed `usize`; the quotient is always below `planned_entries`.
    let suggested_entries =
        (planned_entries as u128 * available_limit_mb as u128) / estimated as u128;
    Err(format!(
        "Estimated memory requirement ({} MB) exceeds limit ({} MB). \
         Reduce transaction count from {} to approximately {}",
        estimated, available_limit_mb, planned_entries, suggested_entries
    ))
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A guard built from a hard limit gets a soft limit at 80% of it.
    #[test]
    fn test_memory_guard_creation() {
        let guard = MemoryGuard::with_limit(1024);
        let (hard, soft) = (guard.config.hard_limit_mb, guard.config.soft_limit_mb);
        assert_eq!(hard, 1024);
        // 80% of 1024, rounded down.
        assert_eq!(soft, 819);
    }

    /// With no hard limit configured, both check entry points succeed.
    #[test]
    fn test_memory_guard_disabled() {
        let guard = MemoryGuard::default_guard();
        assert!(matches!(guard.check(), Ok(())));
        assert!(matches!(guard.check_now(), Ok(())));
    }

    /// A small workload yields a positive estimate below 100 MB.
    #[test]
    fn test_memory_estimation() {
        let est = estimate_memory_mb(1000, 4);
        assert!(est > 0);
        assert!(est < 100);
    }

    /// A small workload fits in 1024 MB; a large one does not fit in 100 MB.
    #[test]
    fn test_sufficient_memory_check() {
        let small = check_sufficient_memory(1000, 4, 1024);
        assert!(small.is_ok());
        let large = check_sufficient_memory(1_000_000, 10, 100);
        assert!(large.is_err());
    }

    /// Calls to `check` on an enabled guard show up in the statistics.
    #[test]
    fn test_stats_tracking() {
        let guard = MemoryGuard::with_limit(10000);
        (0..1000).for_each(|_| {
            let _ = guard.check();
        });
        let checks = guard.stats().checks_performed;
        assert!(checks > 0);
    }

    /// On Linux, memory usage must be readable.
    #[test]
    fn test_is_available() {
        #[cfg(target_os = "linux")]
        assert!(MemoryGuard::is_available());
    }
}