use anyhow::{bail, Result};
use tracing::{info, warn};
/// Returns the current resident set size (RSS) of this process in megabytes.
///
/// On Linux the value comes from `read_rss_procfs`; on every other target it
/// comes from `read_rss_sysinfo`. Exactly one of the two is compiled in.
///
/// # Errors
///
/// Propagates whatever error the platform-specific reader returns.
pub fn current_rss_mb() -> Result<u64> {
    #[cfg(target_os = "linux")]
    let reading = read_rss_procfs();
    #[cfg(not(target_os = "linux"))]
    let reading = read_rss_sysinfo();
    reading
}
#[cfg(target_os = "linux")]
fn read_rss_procfs() -> Result<u64> {
let status = std::fs::read_to_string("/proc/self/status")?;
for line in status.lines() {
if line.starts_with("VmRSS:") {
let kb: u64 = line
.split_whitespace()
.nth(1)
.ok_or_else(|| anyhow::anyhow!("malformed VmRSS line"))?
.parse()
.map_err(|_| anyhow::anyhow!("non-numeric VmRSS value"))?;
return Ok(kb / 1024); }
}
bail!("VmRSS not found in /proc/self/status")
}
/// Reads this process's memory usage through the `sysinfo` crate, in megabytes.
///
/// Only the entry for the current PID is refreshed; scanning every process on
/// the machine is unnecessary when a single process is queried, and this
/// function may be called repeatedly by the memory-cap guard.
///
/// # Errors
///
/// Fails if `sysinfo` has no entry for the current process after the refresh.
#[cfg(not(target_os = "linux"))]
fn read_rss_sysinfo() -> Result<u64> {
    use sysinfo::{Pid, ProcessesToUpdate, System};

    // `process::id()` is a `u32`; widening to `usize` is what `Pid::from` expects.
    let pid = Pid::from(std::process::id() as usize);

    let mut sys = System::new();
    sys.refresh_processes(ProcessesToUpdate::Some(&[pid]), true);

    match sys.process(pid) {
        // `memory()` is reported in bytes; convert to whole megabytes.
        Some(process) => Ok(process.memory() / (1024 * 1024)),
        None => bail!("Could not read process memory via sysinfo"),
    }
}
/// Installs a process-wide address-space ceiling at 110% of the soft cap.
///
/// `mb` is the soft memory cap in megabytes. On Linux the soft `RLIMIT_AS`
/// limit is set to `mb * 1.1` megabytes (rounded down); the process's existing
/// hard `RLIMIT_AS` limit is left untouched. On other platforms nothing is
/// enforced and an informational message is logged instead.
///
/// Failing to set the limit is not treated as an error: a warning is logged
/// and the function still returns `Ok(())`.
///
/// # Errors
///
/// Currently never returns an error; the `Result` is kept for interface
/// stability.
pub fn apply_hard_limit(mb: u64) -> Result<()> {
    // floor(mb * 110 / 100) == mb + mb / 10 for integers. Saturating arithmetic
    // keeps absurdly large caps from overflowing (panic in debug, wrap in release).
    let hard_mb = mb.saturating_add(mb / 10);
    let hard_bytes = hard_mb.saturating_mul(1024 * 1024);

    #[cfg(target_os = "linux")]
    {
        let mut existing = libc::rlimit {
            rlim_cur: libc::RLIM_INFINITY,
            rlim_max: libc::RLIM_INFINITY,
        };
        // SAFETY: `existing` is a valid, exclusively borrowed `rlimit` that
        // outlives the call; `getrlimit` only writes into that struct.
        let queried = unsafe { libc::getrlimit(libc::RLIMIT_AS, &mut existing) };

        let rlim = libc::rlimit {
            rlim_cur: hard_bytes,
            // Preserve the inherited hard limit: an unprivileged process may
            // not raise it, so forcing it to infinity makes `setrlimit` fail
            // whenever a finite hard limit is already in place. Fall back to
            // infinity only when the current limit could not be queried.
            rlim_max: if queried == 0 {
                existing.rlim_max
            } else {
                libc::RLIM_INFINITY
            },
        };
        // SAFETY: `rlim` is a fully initialised `rlimit` that outlives the
        // call; `setrlimit` only reads from it.
        let result = unsafe { libc::setrlimit(libc::RLIMIT_AS, &rlim) };
        if result != 0 {
            let err = std::io::Error::last_os_error();
            warn!(
                "Failed to set RLIMIT_AS to {} MB ({} bytes): {}. Continuing without hard limit.",
                hard_mb, hard_bytes, err
            );
        } else {
            info!(
                "Set hard RLIMIT_AS ceiling to {} MB (110% of {} MB cap)",
                hard_mb, mb
            );
        }
    }

    #[cfg(not(target_os = "linux"))]
    {
        // The byte value is only consumed by the Linux branch.
        let _ = hard_bytes;
        info!(
            "Hard RSS limit not supported on this platform; soft monitoring only (cap = {} MB)",
            mb
        );
    }

    Ok(())
}
/// Soft memory-cap monitor that samples the process RSS at a throttled rate.
///
/// Callers invoke `check` frequently; only every `check_interval`-th call
/// actually reads the RSS and compares it against the cap.
#[derive(Debug)]
pub struct MemoryCapGuard {
    /// Cap in megabytes; an RSS reading above this makes a check fail.
    cap_mb: u64,
    /// RSS level in megabytes above which the one-time warning is logged.
    warn_threshold_mb: u64,
    /// Set once the "approaching cap" warning has been logged.
    warned: bool,
    /// Number of `check` calls made so far.
    check_counter: u64,
    /// RSS is sampled when `check_counter` is a multiple of this value.
    check_interval: u64,
}
impl MemoryCapGuard {
    /// Creates a guard for a cap of `cap_mb` megabytes.
    ///
    /// The warning threshold is 90% of the cap (rounded down) and the RSS is
    /// sampled on every 100th call to [`check`](Self::check).
    pub fn new(cap_mb: u64) -> Self {
        Self {
            cap_mb,
            // Exactly floor(cap_mb * 90 / 100), computed per quotient and
            // remainder so the intermediate product cannot overflow `u64`
            // for very large caps.
            warn_threshold_mb: cap_mb / 100 * 90 + cap_mb % 100 * 90 / 100,
            warned: false,
            check_counter: 0,
            check_interval: 100,
        }
    }

    /// Throttled check: counts the call and samples the RSS only when the
    /// call count is a multiple of the check interval.
    ///
    /// # Errors
    ///
    /// Returns the error from [`check_now`](Self::check_now) on sampled calls;
    /// skipped calls always return `Ok(())`.
    pub fn check(&mut self) -> Result<()> {
        self.check_counter += 1;
        if self.check_counter % self.check_interval != 0 {
            return Ok(());
        }
        self.check_now()
    }

    /// Samples the RSS immediately and compares it against the cap.
    ///
    /// Logs a warning the first time the RSS rises above the warning
    /// threshold. If the RSS cannot be read, a warning is logged and the check
    /// passes, so a broken probe never aborts the caller's work.
    ///
    /// # Errors
    ///
    /// Returns an error when the RSS is strictly greater than the cap.
    pub fn check_now(&mut self) -> Result<()> {
        let rss = match current_rss_mb() {
            Ok(rss) => rss,
            Err(e) => {
                warn!("Could not read RSS for memory cap check: {}", e);
                return Ok(());
            }
        };

        if rss > self.cap_mb {
            bail!(
                "Memory cap exceeded: RSS is {} MB, cap is {} MB. \
                 Indexing stopped gracefully. Increase --max-memory or index a smaller project.",
                rss,
                self.cap_mb
            );
        }

        if rss > self.warn_threshold_mb && !self.warned {
            // Reaching this point implies 0 < rss <= cap_mb, so the divisor
            // below is never zero.
            warn!(
                "Approaching memory cap: RSS is {} MB ({}% of {} MB cap)",
                rss,
                rss.saturating_mul(100) / self.cap_mb,
                self.cap_mb
            );
            self.warned = true;
        }

        Ok(())
    }

    /// Returns the configured cap in megabytes.
    // NOTE(review): presumably not called from every build target, hence the
    // lint suppression; confirm whether it is still needed.
    #[allow(dead_code)]
    pub fn cap_mb(&self) -> u64 {
        self.cap_mb
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The RSS of the test process should be readable and plausible.
    #[test]
    fn test_current_rss_is_reasonable() {
        let rss = current_rss_mb().expect("should be able to read RSS");
        assert!(rss > 0, "RSS should be positive, got {}", rss);
        assert!(rss < 10_000, "RSS should be less than 10 GB, got {}", rss);
    }

    /// A cap far above the real RSS must never trip the guard.
    #[test]
    fn test_memory_cap_guard_under_cap() {
        let mut guard = MemoryCapGuard::new(1_000_000);
        guard.check_interval = 1;
        guard.check().expect("should not error when under cap");
    }

    /// With a cap the process already exceeds, only the call that lands on
    /// the interval may sample the RSS and fail; every earlier call must be
    /// skipped. A cap that can never be exceeded would pass regardless of
    /// whether throttling works.
    #[test]
    fn test_memory_cap_guard_throttling() {
        let mut guard = MemoryCapGuard::new(1);
        guard.check_interval = 1000;
        for _ in 0..999 {
            guard
                .check()
                .expect("throttled calls should not sample RSS");
        }
        assert!(
            guard.check().is_err(),
            "the 1000th call should sample RSS and exceed the cap"
        );
    }

    /// A 1 MB cap is below the RSS of the test process, so the check fails.
    #[test]
    fn test_memory_cap_guard_over_cap() {
        let mut guard = MemoryCapGuard::new(1);
        guard.check_interval = 1;
        let result = guard.check();
        assert!(result.is_err(), "should error when RSS exceeds cap");
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Memory cap exceeded"),
            "error should mention cap exceeded: {}",
            err_msg
        );
    }
}