Skip to main content

lean_ctx/core/
memory_guard.rs

//! Process-level RAM guardian with adaptive eviction and hard OOM protection.
//!
//! Monitors RSS via platform-specific APIs and triggers tiered cache eviction
//! when memory usage exceeds configurable thresholds (default: 5% of system RAM).
//! At critical levels (>3x limit), it runs repeated aggressive eviction and asks
//! background tasks to abort, to avoid being killed by the OS OOM killer.

7use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering};
8use std::sync::Arc;
9
/// Guards `start_guard` against spawning a second guardian thread.
static GUARD_RUNNING: AtomicBool = AtomicBool::new(false);
/// Highest RSS (bytes) recorded by `MemorySnapshot` captures; starts at 0.
static PEAK_RSS: AtomicU64 = AtomicU64::new(0);
/// Raised by the guardian when background tasks should stop; see `abort_requested`.
static ABORT_REQUESTED: AtomicBool = AtomicBool::new(false);
/// Last `PressureLevel` published by the guardian, stored as its `u8` discriminant.
static CURRENT_PRESSURE: AtomicU8 = AtomicU8::new(0);

15/// Current process RSS in bytes, or `None` if unavailable.
16pub fn get_rss_bytes() -> Option<u64> {
17    #[cfg(target_os = "linux")]
18    {
19        linux_rss()
20    }
21    #[cfg(target_os = "macos")]
22    {
23        macos_rss()
24    }
25    #[cfg(not(any(target_os = "linux", target_os = "macos")))]
26    {
27        None
28    }
29}
30
31/// RSS of an arbitrary process by PID, or `None` if unavailable/dead.
32pub fn get_rss_bytes_for_pid(pid: u32) -> Option<u64> {
33    #[cfg(target_os = "linux")]
34    {
35        linux_rss_for_pid(pid)
36    }
37    #[cfg(target_os = "macos")]
38    {
39        macos_rss_for_pid(pid)
40    }
41    #[cfg(not(any(target_os = "linux", target_os = "macos")))]
42    {
43        let _ = pid;
44        None
45    }
46}
47
48/// Total physical RAM in bytes, or `None` if unavailable.
49pub fn get_system_ram_bytes() -> Option<u64> {
50    #[cfg(target_os = "linux")]
51    {
52        linux_memtotal()
53    }
54    #[cfg(target_os = "macos")]
55    {
56        macos_memsize()
57    }
58    #[cfg(not(any(target_os = "linux", target_os = "macos")))]
59    {
60        None
61    }
62}
63
64/// Returns the RSS limit in bytes based on `max_ram_percent` config.
65pub fn rss_limit_bytes() -> Option<u64> {
66    let sys_ram = get_system_ram_bytes()?;
67    let cfg = super::config::Config::load();
68    let pct = super::config::MemoryGuardConfig::effective(&cfg).max_ram_percent;
69    Some(sys_ram / 100 * u64::from(pct))
70}
71
72/// Recorded peak RSS since process start.
73pub fn peak_rss_bytes() -> u64 {
74    PEAK_RSS.load(Ordering::Relaxed)
75}
76
77/// Snapshot of current memory state for diagnostics.
78#[derive(Debug, Clone, serde::Serialize)]
79pub struct MemorySnapshot {
80    pub rss_bytes: u64,
81    pub peak_rss_bytes: u64,
82    pub system_ram_bytes: u64,
83    pub rss_limit_bytes: u64,
84    pub rss_percent: f64,
85    pub pressure_level: PressureLevel,
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, serde::Serialize)]
89#[serde(rename_all = "lowercase")]
90#[repr(u8)]
91pub enum PressureLevel {
92    Normal = 0,
93    Soft = 1,
94    Medium = 2,
95    Hard = 3,
96    Critical = 4,
97}
98
99impl PressureLevel {
100    fn from_u8(v: u8) -> Self {
101        match v {
102            1 => Self::Soft,
103            2 => Self::Medium,
104            3 => Self::Hard,
105            4 => Self::Critical,
106            _ => Self::Normal,
107        }
108    }
109}
110
111impl MemorySnapshot {
112    /// Capture memory snapshot of the **current** process.
113    pub fn capture() -> Option<Self> {
114        Self::capture_impl(get_rss_bytes()?)
115    }
116
117    /// Capture memory snapshot for the **daemon** process (by PID).
118    /// Falls back to the current process if the PID is dead or unreadable.
119    pub fn capture_for_pid(pid: u32) -> Option<Self> {
120        let rss = get_rss_bytes_for_pid(pid).or_else(get_rss_bytes)?;
121        Self::capture_impl(rss)
122    }
123
124    fn capture_impl(rss: u64) -> Option<Self> {
125        let sys = get_system_ram_bytes()?;
126        let limit = rss_limit_bytes()?;
127        let pct = if sys > 0 {
128            (rss as f64 / sys as f64) * 100.0
129        } else {
130            0.0
131        };
132
133        PEAK_RSS.fetch_max(rss, Ordering::Relaxed);
134
135        let cfg = super::config::Config::load();
136        let guard_cfg = super::config::MemoryGuardConfig::effective(&cfg);
137        let base = f64::from(guard_cfg.max_ram_percent);
138
139        let level = if pct > base * 3.0 {
140            PressureLevel::Critical
141        } else if pct > base * 2.0 {
142            PressureLevel::Hard
143        } else if pct > base * 1.4 {
144            PressureLevel::Medium
145        } else if pct > base {
146            PressureLevel::Soft
147        } else {
148            PressureLevel::Normal
149        };
150
151        Some(Self {
152            rss_bytes: rss,
153            peak_rss_bytes: PEAK_RSS.load(Ordering::Relaxed),
154            system_ram_bytes: sys,
155            rss_limit_bytes: limit,
156            rss_percent: pct,
157            pressure_level: level,
158        })
159    }
160}
161
/// Ask jemalloc to purge unused dirty pages from every arena back to the OS.
///
/// Targets `arena.4096.purge`, where 4096 is jemalloc's `MALLCTL_ARENAS_ALL`
/// sentinel meaning "all arenas". Compiles to a no-op unless built with the
/// `jemalloc` feature on a non-Windows target.
///
/// Failures are never fatal. The first one in the process is logged at `warn`
/// so a purge that can never succeed is visible. Later ones are logged at
/// `debug` so the guardian's polling loop does not spam the log.
///
/// NOTE(review): jemalloc documents `arena.<i>.purge` as a `void` mallctl, and
/// its ctl handler appears to reject any supplied new value with `EPERM`.
/// Writing a `u64` here may therefore always fail — confirm. If so, invoke the
/// ctl with null old/new pointers (e.g. via `tikv_jemalloc_sys::mallctl`).
pub fn jemalloc_purge() {
    #[cfg(all(feature = "jemalloc", not(windows)))]
    {
        use std::sync::atomic::{AtomicBool, Ordering};
        use tikv_jemalloc_ctl::raw;

        static FAILURE_REPORTED: AtomicBool = AtomicBool::new(false);

        let purge_mib = b"arena.4096.purge\0";
        // SAFETY: `purge_mib` is a static, NUL-terminated jemalloc MIB name and
        // `raw::write` validates the name and surfaces errors via `Result`.
        let result = unsafe { raw::write(purge_mib, 0u64) };
        if let Err(e) = result {
            if FAILURE_REPORTED.swap(true, Ordering::Relaxed) {
                tracing::debug!("[memory_guard] jemalloc purge failed: {e}");
            } else {
                tracing::warn!("[memory_guard] jemalloc purge failed: {e}");
            }
        }
    }
}

181/// Returns `true` if the guardian has requested background tasks to abort.
182pub fn abort_requested() -> bool {
183    ABORT_REQUESTED.load(Ordering::Relaxed)
184}
185
186/// Quick, non-allocating memory pressure check for hot loops (scanners, indexers).
187/// Reads the cached atomic flag set by the guardian thread — O(1), no syscalls.
188pub fn is_under_pressure() -> bool {
189    current_pressure() >= PressureLevel::Soft
190}
191
192/// Returns the current pressure level as last observed by the guardian thread.
193pub fn current_pressure() -> PressureLevel {
194    PressureLevel::from_u8(CURRENT_PRESSURE.load(Ordering::Relaxed))
195}
196
197/// Start the background memory guardian task (idempotent).
198/// Polls every 3s (normal) or 1s (under pressure). At Critical level, performs
199/// emergency shutdown to prevent OS OOM kill.
200pub fn start_guard(eviction_callback: Arc<dyn Fn(PressureLevel) + Send + Sync>) {
201    if GUARD_RUNNING.swap(true, Ordering::SeqCst) {
202        return;
203    }
204    std::thread::Builder::new()
205        .name("memory-guard".into())
206        .spawn(move || {
207            let mut poll_secs = 3u64;
208            loop {
209                std::thread::sleep(std::time::Duration::from_secs(poll_secs));
210                let Some(snap) = MemorySnapshot::capture() else {
211                    continue;
212                };
213
214                CURRENT_PRESSURE.store(snap.pressure_level as u8, Ordering::Relaxed);
215
216                if snap.pressure_level == PressureLevel::Critical {
217                    tracing::error!(
218                        "[memory_guard] CRITICAL: RSS={:.0}MB ({:.1}% of {:.0}GB) — \
219                         aggressive eviction to prevent OS OOM kill",
220                        snap.rss_bytes as f64 / 1_048_576.0,
221                        snap.rss_percent,
222                        snap.system_ram_bytes as f64 / 1_073_741_824.0,
223                    );
224                    ABORT_REQUESTED.store(true, Ordering::SeqCst);
225                    (eviction_callback)(PressureLevel::Critical);
226                    jemalloc_purge();
227
228                    for attempt in 1..=3 {
229                        std::thread::sleep(std::time::Duration::from_secs(2));
230                        (eviction_callback)(PressureLevel::Critical);
231                        jemalloc_purge();
232                        if let Some(recheck) = MemorySnapshot::capture() {
233                            if recheck.pressure_level < PressureLevel::Hard {
234                                tracing::info!(
235                                    "[memory_guard] eviction attempt {attempt} succeeded — \
236                                     RSS={:.0}MB, pressure={:?}",
237                                    recheck.rss_bytes as f64 / 1_048_576.0,
238                                    recheck.pressure_level,
239                                );
240                                break;
241                            }
242                            tracing::error!(
243                                "[memory_guard] eviction attempt {attempt}/3 — still {:?} \
244                                 (RSS={:.0}MB)",
245                                recheck.pressure_level,
246                                recheck.rss_bytes as f64 / 1_048_576.0,
247                            );
248                        }
249                    }
250                }
251
252                if snap.pressure_level >= PressureLevel::Soft {
253                    poll_secs = 1;
254                    ABORT_REQUESTED
255                        .store(snap.pressure_level >= PressureLevel::Hard, Ordering::SeqCst);
256                    tracing::warn!(
257                        "[memory_guard] pressure={:?} RSS={:.0}MB limit={:.0}MB ({:.1}% of {:.0}GB)",
258                        snap.pressure_level,
259                        snap.rss_bytes as f64 / 1_048_576.0,
260                        snap.rss_limit_bytes as f64 / 1_048_576.0,
261                        snap.rss_percent,
262                        snap.system_ram_bytes as f64 / 1_073_741_824.0,
263                    );
264                    (eviction_callback)(snap.pressure_level);
265
266                    if snap.pressure_level >= PressureLevel::Hard {
267                        jemalloc_purge();
268                    }
269                } else {
270                    poll_secs = 3;
271                    if ABORT_REQUESTED.load(Ordering::Relaxed) {
272                        ABORT_REQUESTED.store(false, Ordering::SeqCst);
273                        tracing::info!("[memory_guard] pressure normalized, clearing abort flag");
274                    }
275                }
276            }
277        })
278        .ok();
279}
280
281/// Force immediate purge of all caches and jemalloc arenas.
282pub fn force_purge() {
283    jemalloc_purge();
284    tracing::info!("[memory_guard] force_purge completed");
285}
286
// --- Platform-specific implementations ---

289#[cfg(target_os = "linux")]
290fn linux_rss() -> Option<u64> {
291    linux_rss_for_pid(std::process::id())
292}
293
/// RSS in bytes of `pid` on Linux, parsed from the `VmRSS:` line of
/// `/proc/<pid>/status` (reported there in kB).
///
/// `None` if the file cannot be read, has no `VmRSS:` line, or the first such
/// line does not parse.
#[cfg(target_os = "linux")]
fn linux_rss_for_pid(pid: u32) -> Option<u64> {
    let status = std::fs::read_to_string(format!("/proc/{pid}/status")).ok()?;
    let field = status
        .lines()
        .find_map(|line| line.strip_prefix("VmRSS:"))?;
    let kib: u64 = field.trim().trim_end_matches(" kB").trim().parse().ok()?;
    Some(kib * 1024)
}

/// Total RAM in bytes on Linux, parsed from the `MemTotal:` line of
/// `/proc/meminfo` (reported there in kB).
///
/// `None` if the file cannot be read, has no `MemTotal:` line, or the first
/// such line does not parse.
#[cfg(target_os = "linux")]
fn linux_memtotal() -> Option<u64> {
    let meminfo = std::fs::read_to_string("/proc/meminfo").ok()?;
    let field = meminfo
        .lines()
        .find_map(|line| line.strip_prefix("MemTotal:"))?;
    let kib: u64 = field.trim().trim_end_matches(" kB").trim().parse().ok()?;
    Some(kib * 1024)
}

/// RSS of the calling process on macOS, via `task_info(MACH_TASK_BASIC_INFO)`.
///
/// `None` if the kernel call does not return `KERN_SUCCESS`.
#[cfg(target_os = "macos")]
#[allow(deprecated, clippy::borrow_as_ptr, clippy::ptr_as_ptr)]
fn macos_rss() -> Option<u64> {
    use std::mem;
    // SAFETY: `mach_task_basic_info_data_t` is a plain C struct; the all-zero
    // bit pattern is a valid value for it.
    let mut basic_info: libc::mach_task_basic_info_data_t = unsafe { mem::zeroed() };
    // Capacity of the out-buffer, expressed in `natural_t` units as task_info expects.
    let mut info_count = (mem::size_of::<libc::mach_task_basic_info_data_t>()
        / mem::size_of::<libc::natural_t>()) as libc::mach_msg_type_number_t;
    // SAFETY: `mach_task_self()` names the current task. `basic_info` and
    // `info_count` are live locals used as out-pointers, and `info_count`
    // matches the buffer size for the `MACH_TASK_BASIC_INFO` flavour.
    let status = unsafe {
        libc::task_info(
            libc::mach_task_self(),
            libc::MACH_TASK_BASIC_INFO,
            std::ptr::from_mut(&mut basic_info).cast::<i32>(),
            std::ptr::from_mut(&mut info_count),
        )
    };
    (status == libc::KERN_SUCCESS).then_some(basic_info.resident_size)
}

/// RSS in bytes of `pid` on macOS, from `ps -o rss= -p <pid>` (reported in kB).
///
/// Shells out because `task_for_pid` needs root/entitlements and
/// `proc_pid_rusage` is private API. `None` if `ps` cannot be run, exits
/// unsuccessfully (e.g. unknown PID), or prints something non-numeric.
#[cfg(target_os = "macos")]
fn macos_rss_for_pid(pid: u32) -> Option<u64> {
    let pid_arg = pid.to_string();
    let output = std::process::Command::new("ps")
        .args(["-o", "rss=", "-p", pid_arg.as_str()])
        .output()
        .ok()
        .filter(|out| out.status.success())?;
    let kib = String::from_utf8_lossy(&output.stdout)
        .trim()
        .parse::<u64>()
        .ok()?;
    Some(kib * 1024)
}

/// Total physical RAM in bytes on macOS, from the `hw.memsize` sysctl.
///
/// `None` if `sysctlbyname` reports failure.
#[cfg(target_os = "macos")]
#[allow(clippy::borrow_as_ptr, clippy::ptr_as_ptr)]
fn macos_memsize() -> Option<u64> {
    let key = b"hw.memsize\0";
    let mut total_bytes: u64 = 0;
    let mut value_len = std::mem::size_of::<u64>();
    // SAFETY: `key` is a static NUL-terminated sysctl name. `total_bytes` and
    // `value_len` are live locals used as out-pointers, and `value_len`
    // matches the size of the `u64` being written. No new value is set.
    let rc = unsafe {
        libc::sysctlbyname(
            key.as_ptr().cast(),
            std::ptr::from_mut(&mut total_bytes).cast::<libc::c_void>(),
            std::ptr::from_mut(&mut value_len),
            std::ptr::null_mut(),
            0,
        )
    };
    (rc == 0).then_some(total_bytes)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rss_returns_some_on_supported_os() {
        if cfg!(any(target_os = "linux", target_os = "macos")) {
            let rss = get_rss_bytes();
            assert!(rss.is_some(), "RSS should be readable");
            assert!(rss.unwrap() > 0, "RSS should be > 0");
        }
    }

    #[test]
    fn system_ram_returns_some_on_supported_os() {
        if cfg!(any(target_os = "linux", target_os = "macos")) {
            let ram = get_system_ram_bytes();
            assert!(ram.is_some(), "System RAM should be readable");
            assert!(ram.unwrap() > 1_000_000, "System RAM should be > 1MB");
        }
    }

    #[test]
    fn snapshot_captures_correctly() {
        if cfg!(any(target_os = "linux", target_os = "macos")) {
            let snap = MemorySnapshot::capture();
            assert!(snap.is_some());
            let s = snap.unwrap();
            assert!(s.rss_bytes > 0);
            assert!(s.system_ram_bytes > s.rss_bytes);
            assert!(s.rss_percent > 0.0 && s.rss_percent < 100.0);
        }
    }

    #[test]
    fn peak_rss_tracks_maximum() {
        // PEAK_RSS is process-global, and other tests raise it concurrently via
        // `MemorySnapshot::capture`. Resetting it and asserting an exact value
        // races with them. Every writer only ever raises the value, so check
        // monotonicity instead.
        let before = peak_rss_bytes();
        PEAK_RSS.fetch_max(before / 2, Ordering::Relaxed);
        assert!(
            peak_rss_bytes() >= before,
            "a smaller sample must not lower the peak"
        );

        let raised = before.saturating_add(1);
        PEAK_RSS.fetch_max(raised, Ordering::Relaxed);
        assert!(
            peak_rss_bytes() >= raised,
            "a larger sample must raise the peak"
        );
    }

    #[test]
    fn pressure_level_roundtrip() {
        for level in [
            PressureLevel::Normal,
            PressureLevel::Soft,
            PressureLevel::Medium,
            PressureLevel::Hard,
            PressureLevel::Critical,
        ] {
            assert_eq!(PressureLevel::from_u8(level as u8), level);
        }
    }

    #[test]
    fn atomic_pressure_defaults_to_normal() {
        assert_eq!(current_pressure(), PressureLevel::Normal);
    }

    #[test]
    fn rss_for_own_pid_matches_self() {
        if cfg!(any(target_os = "linux", target_os = "macos")) {
            let self_rss = get_rss_bytes().unwrap();
            let pid_rss = get_rss_bytes_for_pid(std::process::id()).unwrap();
            let ratio = self_rss as f64 / pid_rss as f64;
            assert!(
                (0.5..2.0).contains(&ratio),
                "self RSS ({self_rss}) and pid-based RSS ({pid_rss}) should be within 2x"
            );
        }
    }

    #[test]
    fn rss_for_dead_pid_returns_none() {
        let dead_pid = 999_999_999u32;
        assert!(get_rss_bytes_for_pid(dead_pid).is_none());
    }

    #[test]
    fn capture_for_pid_falls_back_on_dead_pid() {
        if cfg!(any(target_os = "linux", target_os = "macos")) {
            let snap = MemorySnapshot::capture_for_pid(999_999_999);
            assert!(snap.is_some(), "should fall back to self RSS");
        }
    }
}
469            let snap = MemorySnapshot::capture_for_pid(999_999_999);
470            assert!(snap.is_some(), "should fall back to self RSS");
471        }
472    }
473}