1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::thread::JoinHandle;
4use std::time::Duration;
5
/// Background sampler that records the highest resident set size (in MiB, as
/// reported by [`get_rss_mb`]) observed while it is running.
#[derive(Debug)]
pub struct RssPeakSampler {
    /// Highest RSS value (MiB) seen so far; starts at the `seed_mb` given to `start`.
    peak: Arc<AtomicUsize>,
    /// Set to ask the sampler thread to exit its loop.
    stop: Arc<AtomicBool>,
    /// Handle of the sampler thread; `None` once it has been taken and joined.
    handle: Option<JoinHandle<()>>,
}
13
14impl RssPeakSampler {
15 pub fn start(seed_mb: usize, interval_ms: u64) -> Self {
18 let peak = Arc::new(AtomicUsize::new(seed_mb));
19 let stop = Arc::new(AtomicBool::new(false));
20 let peak_c = Arc::clone(&peak);
21 let stop_c = Arc::clone(&stop);
22 let handle = std::thread::Builder::new()
23 .name("rivet-rss-peak".into())
24 .spawn(move || {
25 while !stop_c.load(Ordering::Relaxed) {
26 let r = get_rss_mb();
27 peak_c.fetch_max(r, Ordering::Relaxed);
28 std::thread::sleep(Duration::from_millis(interval_ms));
29 }
30 let r = get_rss_mb();
31 peak_c.fetch_max(r, Ordering::Relaxed);
32 })
33 .expect("spawn rss peak sampler");
34 Self {
35 peak,
36 stop,
37 handle: Some(handle),
38 }
39 }
40
41 pub fn stop(mut self) -> usize {
43 self.stop.store(true, Ordering::Relaxed);
44 if let Some(h) = self.handle.take() {
45 let _ = h.join();
46 }
47 let last = get_rss_mb();
48 self.peak.load(Ordering::Relaxed).max(last)
49 }
50}
51
52pub fn get_rss_mb() -> usize {
54 #[cfg(target_os = "macos")]
55 {
56 macos_rss_mb()
57 }
58 #[cfg(target_os = "linux")]
59 {
60 linux_rss_mb()
61 }
62 #[cfg(not(any(target_os = "macos", target_os = "linux")))]
63 {
64 0
65 }
66}
67
/// Returns this process's resident set size in MiB (integer-truncated) on macOS.
///
/// Calls `task_info` with the `MACH_TASK_BASIC_INFO` flavor for the current
/// task. `info.resident_size` is divided by `1024 * 1024`, i.e. treated as a
/// byte count. Returns `0` if `task_info` does not report `KERN_SUCCESS`.
#[cfg(target_os = "macos")]
fn macos_rss_mb() -> usize {
    use std::mem;
    // SAFETY: `info` is a zero-initialised, locally owned out-parameter and is
    // only read after `task_info` returns `KERN_SUCCESS`. `count` is set to the
    // size of `mach_task_basic_info_data_t` measured in `natural_t` units,
    // which bounds how much `task_info` may write into `info`.
    // NOTE(review): this assumes an all-zero bit pattern is valid for
    // `mach_task_basic_info_data_t` (plain integer fields) — confirm against
    // the libc definition if that type changes.
    unsafe {
        let mut info: libc::mach_task_basic_info_data_t = mem::zeroed();
        // Equivalent of the C `MACH_TASK_BASIC_INFO_COUNT` macro.
        let mut count = (mem::size_of::<libc::mach_task_basic_info_data_t>()
            / mem::size_of::<libc::natural_t>())
            as libc::mach_msg_type_number_t;
        let kr = libc::task_info(
            mach2::traps::mach_task_self(),
            libc::MACH_TASK_BASIC_INFO,
            &mut info as *mut _ as libc::task_info_t,
            &mut count,
        );
        if kr == libc::KERN_SUCCESS {
            (info.resident_size as usize) / (1024 * 1024)
        } else {
            // Best-effort: report 0 rather than failing the caller.
            0
        }
    }
}
92
/// Returns this process's resident set size in MiB (integer-truncated) on Linux.
///
/// Reads the `VmRSS` line of `/proc/self/status`, which the kernel reports in
/// kB. Unlike `/proc/self/statm`, which counts pages, this does not depend on
/// the system page size (4 KiB is not universal, e.g. 16/64 KiB on some
/// aarch64 kernels). Returns `0` if the file cannot be read or has no
/// parsable `VmRSS` line.
#[cfg(target_os = "linux")]
fn linux_rss_mb() -> usize {
    std::fs::read_to_string("/proc/self/status")
        .ok()
        .as_deref()
        .and_then(parse_vm_rss_kb)
        .map(|kb| kb / 1024)
        .unwrap_or(0)
}

/// Extracts the numeric `VmRSS` value (in kB) from `/proc/<pid>/status` text.
#[cfg(target_os = "linux")]
fn parse_vm_rss_kb(status: &str) -> Option<usize> {
    status
        .lines()
        .find_map(|line| line.strip_prefix("VmRSS:"))
        .and_then(|rest| rest.split_whitespace().next())
        .and_then(|value| value.parse().ok())
}
101
102pub fn check_memory(threshold_mb: usize) -> bool {
103 if threshold_mb == 0 {
104 return true;
105 }
106 let rss = get_rss_mb();
107 if rss > threshold_mb {
108 log::warn!("RSS {}MB exceeds threshold {}MB", rss, threshold_mb);
109 return false;
110 }
111 true
112}
113
/// A blocking counting semaphore built from a `Mutex<usize>` and a `Condvar`.
///
/// At most `max` permits may be held at the same time.
#[derive(Debug)]
pub struct Semaphore {
    /// Number of permits currently held.
    state: std::sync::Mutex<usize>,
    /// Used to wake waiters in `acquire` when a permit is released.
    cond: std::sync::Condvar,
    /// Maximum number of permits that may be held simultaneously.
    max: usize,
}
133
134impl Semaphore {
135 pub fn new(max: usize) -> Self {
136 Self {
137 state: std::sync::Mutex::new(0),
138 cond: std::sync::Condvar::new(),
139 max,
140 }
141 }
142
143 pub fn acquire(&self) {
145 let mut count = self
146 .state
147 .lock()
148 .unwrap_or_else(std::sync::PoisonError::into_inner);
149 while *count >= self.max {
150 count = self
151 .cond
152 .wait(count)
153 .unwrap_or_else(std::sync::PoisonError::into_inner);
154 }
155 *count += 1;
156 }
157
158 pub fn release(&self) {
160 let mut count = self
161 .state
162 .lock()
163 .unwrap_or_else(std::sync::PoisonError::into_inner);
164 debug_assert!(*count > 0, "release without matching acquire");
165 *count -= 1;
166 self.cond.notify_one();
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn check_memory_zero_threshold_always_passes() {
        assert!(check_memory(0));
    }

    #[test]
    fn check_memory_huge_threshold_passes() {
        // 1_024 * 1_024 MiB = 1 TiB, far above any test process's RSS.
        assert!(check_memory(1_024 * 1_024));
    }

    #[test]
    fn get_rss_mb_does_not_panic() {
        let _ = get_rss_mb();
    }

    #[test]
    fn rss_peak_sampler_stop_returns_value() {
        let sampler = RssPeakSampler::start(0, 50);
        let _peak = sampler.stop();
    }

    #[test]
    fn rss_peak_sampler_seed_is_lower_bound() {
        let high_seed = 9999;
        let sampler = RssPeakSampler::start(high_seed, 50);
        let peak = sampler.stop();
        assert!(peak >= high_seed);
    }

    #[test]
    fn semaphore_admits_up_to_max_without_blocking() {
        let sem = Semaphore::new(3);
        sem.acquire();
        sem.acquire();
        sem.acquire();
        sem.release();
        sem.release();
        sem.release();
    }

    #[test]
    fn semaphore_blocks_third_until_release() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        // Both permits of a max-2 semaphore are held by this thread.
        let sem = Arc::new(Semaphore::new(2));
        sem.acquire();
        sem.acquire();

        let entered = Arc::new(AtomicBool::new(false));
        let entered_w = Arc::clone(&entered);
        let sem_w = Arc::clone(&sem);
        // The worker's acquire is the third one and must block.
        let handle = std::thread::spawn(move || {
            sem_w.acquire();
            entered_w.store(true, Ordering::Release);
            sem_w.release();
        });

        // Give the worker time to reach `acquire`. This can only miss a bug
        // (worker not yet scheduled); it cannot produce a false failure.
        std::thread::sleep(std::time::Duration::from_millis(50));
        assert!(
            !entered.load(Ordering::Acquire),
            "worker must be blocked while 2/2 permits are taken"
        );

        sem.release();
        handle.join().expect("worker thread");
        assert!(
            entered.load(Ordering::Acquire),
            "worker should have entered after release"
        );
        sem.release();
    }
}