// rivet/resource.rs

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::thread::JoinHandle;
use std::time::Duration;
5
6/// Background thread polls RSS while an export runs so `peak_rss_mb` reflects in-process highs,
7/// not only values at start/end (which often miss parallel workers' allocations).
8pub struct RssPeakSampler {
9    peak: Arc<AtomicUsize>,
10    stop: Arc<AtomicBool>,
11    handle: Option<JoinHandle<()>>,
12}
13
14impl RssPeakSampler {
15    /// Spawns a sampler every `interval_ms` (default-quality tradeoff: 100ms).
16    /// `seed_mb` is folded into the peak (typically RSS immediately before work starts).
17    pub fn start(seed_mb: usize, interval_ms: u64) -> Self {
18        let peak = Arc::new(AtomicUsize::new(seed_mb));
19        let stop = Arc::new(AtomicBool::new(false));
20        let peak_c = Arc::clone(&peak);
21        let stop_c = Arc::clone(&stop);
22        let handle = std::thread::Builder::new()
23            .name("rivet-rss-peak".into())
24            .spawn(move || {
25                while !stop_c.load(Ordering::Relaxed) {
26                    let r = get_rss_mb();
27                    peak_c.fetch_max(r, Ordering::Relaxed);
28                    std::thread::sleep(Duration::from_millis(interval_ms));
29                }
30                let r = get_rss_mb();
31                peak_c.fetch_max(r, Ordering::Relaxed);
32            })
33            .expect("spawn rss peak sampler");
34        Self {
35            peak,
36            stop,
37            handle: Some(handle),
38        }
39    }
40
41    /// Stops sampling and returns the highest RSS (MB) seen, including a final read.
42    pub fn stop(mut self) -> usize {
43        self.stop.store(true, Ordering::Relaxed);
44        if let Some(h) = self.handle.take() {
45            let _ = h.join();
46        }
47        let last = get_rss_mb();
48        self.peak.load(Ordering::Relaxed).max(last)
49    }
50}
51
52/// Returns the current process RSS in megabytes, or 0 if unavailable.
53pub fn get_rss_mb() -> usize {
54    #[cfg(target_os = "macos")]
55    {
56        macos_rss_mb()
57    }
58    #[cfg(target_os = "linux")]
59    {
60        linux_rss_mb()
61    }
62    #[cfg(not(any(target_os = "macos", target_os = "linux")))]
63    {
64        0
65    }
66}
67
/// Queries the Mach kernel for this task's resident size and converts it to
/// megabytes. Yields 0 when `task_info` does not report success.
#[cfg(target_os = "macos")]
fn macos_rss_mb() -> usize {
    const BYTES_PER_MB: usize = 1024 * 1024;

    // `task_info` wants the buffer length expressed in `natural_t` words.
    let info_words = std::mem::size_of::<libc::mach_task_basic_info_data_t>()
        / std::mem::size_of::<libc::natural_t>();

    // SAFETY: `task_info` is a stable macOS kernel API. The output struct is
    // zero-initialised, the word count passed alongside it matches the
    // struct's size, and the pointer is only used for the duration of the
    // call. The field is copied out by value, so reading it is fine even if
    // the kernel call failed (it is then still the zeroed value).
    let (status, resident_bytes) = unsafe {
        let mut info: libc::mach_task_basic_info_data_t = std::mem::zeroed();
        let mut count = info_words as libc::mach_msg_type_number_t;
        let status = libc::task_info(
            mach2::traps::mach_task_self(),
            libc::MACH_TASK_BASIC_INFO,
            &mut info as *mut _ as libc::task_info_t,
            &mut count,
        );
        (status, info.resident_size)
    };

    if status != libc::KERN_SUCCESS {
        return 0;
    }
    (resident_bytes as usize) / BYTES_PER_MB
}
92
/// Reads this process's resident set size from `/proc/self/status` and returns
/// it in megabytes, or 0 if the file cannot be read or parsed.
///
/// The `VmRSS` field is reported by the kernel in kB, so the result is correct
/// regardless of the system page size. (Converting the page count from
/// `/proc/self/statm` with a fixed 4096-byte page under-reports by 4x/16x on
/// kernels configured with 16K/64K pages.)
#[cfg(target_os = "linux")]
fn linux_rss_mb() -> usize {
    std::fs::read_to_string("/proc/self/status")
        .ok()
        .and_then(|status| parse_vm_rss_kb(&status))
        .map_or(0, |kb| kb / 1024)
}

/// Extracts the numeric value (in kB) of the `VmRSS:` line from the text of a
/// `/proc/<pid>/status` file. Returns `None` if the line is absent or its
/// value is not a non-negative integer.
#[cfg(target_os = "linux")]
fn parse_vm_rss_kb(status: &str) -> Option<usize> {
    status
        .lines()
        .find_map(|line| line.strip_prefix("VmRSS:"))?
        .split_whitespace()
        .next()?
        .parse()
        .ok()
}
101
102pub fn check_memory(threshold_mb: usize) -> bool {
103    if threshold_mb == 0 {
104        return true;
105    }
106    let rss = get_rss_mb();
107    if rss > threshold_mb {
108        log::warn!("RSS {}MB exceeds threshold {}MB", rss, threshold_mb);
109        return false;
110    }
111    true
112}
113
/// Counting semaphore built on `Mutex + Condvar` so blocked acquirers park in
/// the kernel rather than spinning on an atomic.
///
/// Replaces the prior pattern in `pipeline::chunked::exec`:
/// ```ignore
/// while atomic.load(Relaxed) >= max { thread::sleep(50ms); }
/// atomic.fetch_add(1, Relaxed);
/// ```
/// which polled 20 times/sec per blocked worker. Under N parallel chunks and a
/// long-running export this burned ~N × 20 wake-ups per second doing nothing.
///
/// With `Condvar::wait`, blocked threads consume zero CPU until a `release()`
/// notifies. The mutex is uncontended whenever traffic is light, so the lock
/// path adds no measurable overhead vs the atomic.
///
/// A poisoned mutex is recovered rather than propagated: the protected state
/// is a plain counter, so it is always safe to keep using.
#[derive(Debug)]
pub struct Semaphore {
    /// Number of permits currently held.
    state: std::sync::Mutex<usize>,
    /// Signalled whenever a permit is returned.
    cond: std::sync::Condvar,
    /// Maximum number of permits that may be held at once.
    max: usize,
}

impl Semaphore {
    /// Creates a semaphore that admits at most `max` concurrent holders.
    ///
    /// With `max == 0` no permit can ever be granted, so `acquire` blocks forever.
    pub fn new(max: usize) -> Self {
        Self {
            state: std::sync::Mutex::new(0),
            cond: std::sync::Condvar::new(),
            max,
        }
    }

    /// Locks the permit counter, recovering the guard if the mutex is poisoned.
    fn lock_count(&self) -> std::sync::MutexGuard<'_, usize> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Block until a permit is available, then take one.
    pub fn acquire(&self) {
        // `wait_while` re-checks the predicate after every wake-up, so
        // spurious wake-ups and lost races against other acquirers are handled.
        let mut count = self
            .cond
            .wait_while(self.lock_count(), |in_use| *in_use >= self.max)
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *count += 1;
    }

    /// Return a permit and wake one blocked acquirer (if any).
    ///
    /// Calling this without a matching `acquire` is a caller bug: it trips a
    /// debug assertion, and in release builds it is ignored (the counter stays
    /// at zero) instead of wrapping around and blocking every later `acquire`.
    pub fn release(&self) {
        {
            let mut count = self.lock_count();
            debug_assert!(*count > 0, "release without matching acquire");
            *count = (*count).saturating_sub(1);
        }
        // Notify after the guard is dropped so the woken thread can take the
        // mutex immediately instead of blocking on it again.
        self.cond.notify_one();
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;

    // ── Memory checks and RSS sampling ──────────────────────────────────────

    #[test]
    fn check_memory_zero_threshold_always_passes() {
        // A zero threshold short-circuits to `true` without reading RSS.
        assert!(check_memory(0));
    }

    #[test]
    fn check_memory_huge_threshold_passes() {
        // 1_024 * 1_024 MB is 1 TB, far beyond what any test process will use.
        let one_tb_in_mb = 1_024 * 1_024;
        assert!(check_memory(one_tb_in_mb));
    }

    #[test]
    fn get_rss_mb_does_not_panic() {
        // Only the absence of a panic matters; the value is platform-dependent.
        let _ = get_rss_mb();
    }

    #[test]
    fn rss_peak_sampler_stop_returns_value() {
        // Smoke test: start and stop must complete without panicking.
        let _peak = RssPeakSampler::start(0, 50).stop();
    }

    #[test]
    fn rss_peak_sampler_seed_is_lower_bound() {
        // The seed is folded into the peak, so the result can never be below it.
        let seed_mb = 9999;
        let observed = RssPeakSampler::start(seed_mb, 50).stop();
        assert!(observed >= seed_mb);
    }

    // ── Semaphore ───────────────────────────────────────────────────────────

    #[test]
    fn semaphore_admits_up_to_max_without_blocking() {
        let sem = Semaphore::new(3);
        for _ in 0..3 {
            sem.acquire();
        }
        // All three permits were granted without deadlocking; hand them back.
        for _ in 0..3 {
            sem.release();
        }
    }

    #[test]
    fn semaphore_blocks_fourth_until_release() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::time::Duration;

        // Exhaust both permits on the test thread.
        let gate = Arc::new(Semaphore::new(2));
        for _ in 0..2 {
            gate.acquire();
        }

        let got_in = Arc::new(AtomicBool::new(false));
        let worker = {
            let gate = Arc::clone(&gate);
            let got_in = Arc::clone(&got_in);
            std::thread::spawn(move || {
                gate.acquire();
                got_in.store(true, Ordering::Release);
                gate.release();
            })
        };

        // Give the worker time to reach `acquire()`; it must still be waiting.
        std::thread::sleep(Duration::from_millis(50));
        assert!(
            !got_in.load(Ordering::Acquire),
            "worker must be blocked while 2/2 permits are taken"
        );

        // Handing back a single permit is enough to let the worker through.
        gate.release();
        worker.join().expect("worker thread");
        assert!(
            got_in.load(Ordering::Acquire),
            "worker should have entered after release"
        );
        gate.release();
    }
}