Skip to main content

rivet_cli/
resource.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::thread::JoinHandle;
4use std::time::Duration;
5
/// Tracks the peak resident set size (RSS) of the current process while work runs.
///
/// A background thread polls RSS while an export runs so `peak_rss_mb` reflects in-process
/// highs, not only values at start/end (which often miss parallel workers' allocations).
#[derive(Debug)]
#[must_use = "an RssPeakSampler does nothing useful unless `stop()` is called to read the peak"]
pub struct RssPeakSampler {
    /// Highest RSS observed so far, in MB (seeded by `start`).
    peak: Arc<AtomicUsize>,
    /// Set to ask the sampler thread to exit.
    stop: Arc<AtomicBool>,
    /// Sampler thread handle; `None` when there is no thread left to join.
    handle: Option<JoinHandle<()>>,
}
13
14impl RssPeakSampler {
15    /// Spawns a sampler every `interval_ms` (default-quality tradeoff: 100ms).
16    /// `seed_mb` is folded into the peak (typically RSS immediately before work starts).
17    pub fn start(seed_mb: usize, interval_ms: u64) -> Self {
18        let peak = Arc::new(AtomicUsize::new(seed_mb));
19        let stop = Arc::new(AtomicBool::new(false));
20        let peak_c = Arc::clone(&peak);
21        let stop_c = Arc::clone(&stop);
22        let handle = std::thread::Builder::new()
23            .name("rivet-rss-peak".into())
24            .spawn(move || {
25                while !stop_c.load(Ordering::Relaxed) {
26                    let r = get_rss_mb();
27                    peak_c.fetch_max(r, Ordering::Relaxed);
28                    std::thread::sleep(Duration::from_millis(interval_ms));
29                }
30                let r = get_rss_mb();
31                peak_c.fetch_max(r, Ordering::Relaxed);
32            })
33            .expect("spawn rss peak sampler");
34        Self {
35            peak,
36            stop,
37            handle: Some(handle),
38        }
39    }
40
41    /// Stops sampling and returns the highest RSS (MB) seen, including a final read.
42    pub fn stop(mut self) -> usize {
43        self.stop.store(true, Ordering::Relaxed);
44        if let Some(h) = self.handle.take() {
45            let _ = h.join();
46        }
47        let last = get_rss_mb();
48        self.peak.load(Ordering::Relaxed).max(last)
49    }
50}
51
52/// Returns the current process RSS in megabytes, or 0 if unavailable.
53pub fn get_rss_mb() -> usize {
54    #[cfg(target_os = "macos")]
55    {
56        macos_rss_mb()
57    }
58    #[cfg(target_os = "linux")]
59    {
60        linux_rss_mb()
61    }
62    #[cfg(not(any(target_os = "macos", target_os = "linux")))]
63    {
64        0
65    }
66}
67
#[cfg(target_os = "macos")]
fn macos_rss_mb() -> usize {
    use std::mem;

    // `task_info` takes its output-buffer length in `natural_t` words, not bytes.
    let mut word_count = (mem::size_of::<libc::mach_task_basic_info_data_t>()
        / mem::size_of::<libc::natural_t>()) as libc::mach_msg_type_number_t;

    // SAFETY: `task_info` is a stable macOS kernel API. The output struct is
    // zero-initialised plain data, `word_count` matches its size in `natural_t` words,
    // and the struct's contents are only used after the return code is checked below.
    let (status, basic_info) = unsafe {
        let mut basic_info: libc::mach_task_basic_info_data_t = mem::zeroed();
        let status = libc::task_info(
            mach2::traps::mach_task_self(),
            libc::MACH_TASK_BASIC_INFO,
            &mut basic_info as *mut _ as libc::task_info_t,
            &mut word_count,
        );
        (status, basic_info)
    };

    if status != libc::KERN_SUCCESS {
        return 0;
    }
    (basic_info.resident_size as usize) / (1024 * 1024)
}
92
/// Current RSS in MB from `/proc/self/status`, or 0 if it cannot be read.
///
/// Uses the `VmRSS` field, which the kernel reports in kB. It is therefore independent of
/// the page size; `/proc/self/statm` counts pages, and those are not always 4 KiB
/// (e.g. 16K/64K-page aarch64 kernels).
#[cfg(target_os = "linux")]
fn linux_rss_mb() -> usize {
    std::fs::read_to_string("/proc/self/status")
        .ok()
        .and_then(|status| parse_vm_rss_kb(&status))
        .map_or(0, |kb| kb / 1024)
}

/// Extracts the `VmRSS` value (in kB) from the text of a `/proc/<pid>/status` file.
#[cfg(target_os = "linux")]
fn parse_vm_rss_kb(status: &str) -> Option<usize> {
    status
        .lines()
        .find_map(|line| line.strip_prefix("VmRSS:"))
        .and_then(|rest| rest.split_whitespace().next())
        .and_then(|kb| kb.parse().ok())
}
101
102pub fn check_memory(threshold_mb: usize) -> bool {
103    if threshold_mb == 0 {
104        return true;
105    }
106    let rss = get_rss_mb();
107    if rss > threshold_mb {
108        log::warn!("RSS {}MB exceeds threshold {}MB", rss, threshold_mb);
109        return false;
110    }
111    true
112}