1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::thread::JoinHandle;
4use std::time::Duration;
5
6pub struct RssPeakSampler {
9 peak: Arc<AtomicUsize>,
10 stop: Arc<AtomicBool>,
11 handle: Option<JoinHandle<()>>,
12}
13
14impl RssPeakSampler {
15 pub fn start(seed_mb: usize, interval_ms: u64) -> Self {
18 let peak = Arc::new(AtomicUsize::new(seed_mb));
19 let stop = Arc::new(AtomicBool::new(false));
20 let peak_c = Arc::clone(&peak);
21 let stop_c = Arc::clone(&stop);
22 let handle = std::thread::Builder::new()
23 .name("rivet-rss-peak".into())
24 .spawn(move || {
25 while !stop_c.load(Ordering::Relaxed) {
26 let r = get_rss_mb();
27 peak_c.fetch_max(r, Ordering::Relaxed);
28 std::thread::sleep(Duration::from_millis(interval_ms));
29 }
30 let r = get_rss_mb();
31 peak_c.fetch_max(r, Ordering::Relaxed);
32 })
33 .expect("spawn rss peak sampler");
34 Self {
35 peak,
36 stop,
37 handle: Some(handle),
38 }
39 }
40
41 pub fn stop(mut self) -> usize {
43 self.stop.store(true, Ordering::Relaxed);
44 if let Some(h) = self.handle.take() {
45 let _ = h.join();
46 }
47 let last = get_rss_mb();
48 self.peak.load(Ordering::Relaxed).max(last)
49 }
50}
51
52pub fn get_rss_mb() -> usize {
54 #[cfg(target_os = "macos")]
55 {
56 macos_rss_mb()
57 }
58 #[cfg(target_os = "linux")]
59 {
60 linux_rss_mb()
61 }
62 #[cfg(not(any(target_os = "macos", target_os = "linux")))]
63 {
64 0
65 }
66}
67
/// macOS probe behind `get_rss_mb`.
///
/// Queries the current task with `task_info(MACH_TASK_BASIC_INFO)`. Returns
/// `resident_size` converted to MiB (rounded down), or 0 when `task_info`
/// does not return `KERN_SUCCESS`.
#[cfg(target_os = "macos")]
fn macos_rss_mb() -> usize {
    use std::mem;
    // SAFETY (review note — assumed, confirm against the Mach headers):
    // `mach_task_basic_info_data_t` is presumed to be a plain C struct for
    // which all-zero bytes are a valid value. `task_info` is handed a pointer
    // to that local struct, with `count` set to the struct's size in
    // `natural_t` units; it is expected to write no more than that.
    unsafe {
        let mut info: libc::mach_task_basic_info_data_t = mem::zeroed();
        // In/out buffer length for task_info, measured in natural_t words
        // rather than bytes.
        let mut count = (mem::size_of::<libc::mach_task_basic_info_data_t>()
            / mem::size_of::<libc::natural_t>())
            as libc::mach_msg_type_number_t;
        let kr = libc::task_info(
            mach2::traps::mach_task_self(),
            libc::MACH_TASK_BASIC_INFO,
            &mut info as *mut _ as libc::task_info_t,
            &mut count,
        );
        if kr == libc::KERN_SUCCESS {
            // Bytes -> MiB, truncating.
            (info.resident_size as usize) / (1024 * 1024)
        } else {
            0
        }
    }
}
92
/// Linux probe behind `get_rss_mb`.
///
/// Reads the `VmRSS` field of `/proc/self/status` and converts it to MiB,
/// rounding down. `VmRSS` is reported in kB, so the result does not depend on
/// the page size. `/proc/self/statm` counts pages, and a hard-coded 4 KiB page
/// under-reports RSS on kernels built with 16 KiB or 64 KiB pages.
///
/// Returns 0 if the file cannot be read or the field is missing or malformed.
#[cfg(target_os = "linux")]
fn linux_rss_mb() -> usize {
    std::fs::read_to_string("/proc/self/status")
        .ok()
        .and_then(|status| parse_vmrss_kb(&status))
        .map(|kb| kb / 1024)
        .unwrap_or(0)
}

/// Extracts the `VmRSS:` value, in kB, from the text of a `/proc/<pid>/status`
/// file. Returns `None` if the line is absent, the number does not parse, or
/// the unit is something other than `kB`.
#[cfg(target_os = "linux")]
fn parse_vmrss_kb(status: &str) -> Option<usize> {
    status.lines().find_map(|line| {
        let mut fields = line.strip_prefix("VmRSS:")?.split_whitespace();
        let value = fields.next()?.parse::<usize>().ok()?;
        match fields.next() {
            Some("kB") | None => Some(value),
            Some(_) => None,
        }
    })
}
101
102pub fn check_memory(threshold_mb: usize) -> bool {
103 if threshold_mb == 0 {
104 return true;
105 }
106 let rss = get_rss_mb();
107 if rss > threshold_mb {
108 log::warn!("RSS {}MB exceeds threshold {}MB", rss, threshold_mb);
109 return false;
110 }
111 true
112}
113
/// Counting semaphore whose ceiling can be changed at runtime.
///
/// [`acquire`](Semaphore::acquire) blocks while the number of held permits is
/// at or above the current ceiling. [`release`](Semaphore::release) returns a
/// permit. [`resize`](Semaphore::resize) changes the ceiling.
///
/// Mutex poisoning is deliberately tolerated. The protected state is two
/// integers, each updated in a single statement, so a panicking holder cannot
/// leave it half-written.
#[derive(Debug)]
pub struct Semaphore {
    state: std::sync::Mutex<SemState>,
    cond: std::sync::Condvar,
}

/// Mutex-protected bookkeeping for [`Semaphore`].
#[derive(Debug)]
struct SemState {
    /// Permits currently held.
    count: usize,
    /// Ceiling on `count`. It can temporarily be below `count` after a
    /// shrinking `resize`.
    max: usize,
}

impl Semaphore {
    /// Creates a semaphore with no permits held and a ceiling of `max`.
    ///
    /// With `max == 0`, every `acquire` blocks until the ceiling is raised.
    pub fn new(max: usize) -> Self {
        Self {
            state: std::sync::Mutex::new(SemState { count: 0, max }),
            cond: std::sync::Condvar::new(),
        }
    }

    /// Locks the state, recovering the guard if a previous holder panicked.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, SemState> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Takes a permit, blocking while `count >= max`.
    ///
    /// Each call must be paired with exactly one [`release`](Self::release).
    pub fn acquire(&self) {
        let mut st = self
            .cond
            .wait_while(self.lock_state(), |s| s.count >= s.max)
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        st.count += 1;
    }

    /// Returns a permit and wakes one blocked `acquire`, if any.
    ///
    /// # Panics
    ///
    /// Panics in debug builds if no permit is held. In release builds an
    /// unmatched call is ignored rather than corrupting the count.
    pub fn release(&self) {
        let mut st = self.lock_state();
        debug_assert!(st.count > 0, "release without matching acquire");
        // Saturate instead of `-= 1`. Release builds skip the debug_assert
        // and overflow checks, so an unmatched release would wrap count to
        // usize::MAX and block every future acquire forever.
        st.count = st.count.saturating_sub(1);
        self.cond.notify_one();
    }

    /// Sets the ceiling to `new_max`.
    ///
    /// Raising the ceiling wakes every blocked acquirer so all that now fit
    /// can proceed. Lowering it never revokes permits already held. New
    /// acquires keep blocking until enough permits are released that
    /// `count < new_max`.
    pub fn resize(&self, new_max: usize) {
        let mut st = self.lock_state();
        let raised = new_max > st.max;
        st.max = new_max;
        if raised {
            self.cond.notify_all();
        }
    }

    /// Current ceiling (test-only introspection).
    #[cfg(test)]
    pub fn current_max(&self) -> usize {
        self.lock_state().max
    }
}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread::JoinHandle;
    use std::time::Duration;

    /// How long to wait so a spawned worker can reach `acquire` and park.
    const SETTLE: Duration = Duration::from_millis(50);

    /// Spawns a thread that takes a permit from `sem`, raises the returned
    /// flag once it is inside, and then gives the permit back.
    fn spawn_acquirer(sem: &Arc<Semaphore>) -> (Arc<AtomicBool>, JoinHandle<()>) {
        let entered = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&entered);
        let worker_sem = Arc::clone(sem);
        let handle = std::thread::spawn(move || {
            worker_sem.acquire();
            flag.store(true, Ordering::Release);
            worker_sem.release();
        });
        (entered, handle)
    }

    #[test]
    fn check_memory_zero_threshold_always_passes() {
        assert!(check_memory(0));
    }

    #[test]
    fn check_memory_huge_threshold_passes() {
        // 1 TiB expressed in MiB: far above any realistic test-process RSS.
        let one_tib_in_mb = 1_024 * 1_024;
        assert!(check_memory(one_tib_in_mb));
    }

    #[test]
    fn get_rss_mb_does_not_panic() {
        let _ = get_rss_mb();
    }

    #[test]
    fn rss_peak_sampler_stop_returns_value() {
        let _peak = RssPeakSampler::start(0, 50).stop();
    }

    #[test]
    fn rss_peak_sampler_seed_is_lower_bound() {
        let high_seed = 9999;
        let peak = RssPeakSampler::start(high_seed, 50).stop();
        assert!(peak >= high_seed);
    }

    #[test]
    fn semaphore_admits_up_to_max_without_blocking() {
        let sem = Semaphore::new(3);
        for _ in 0..3 {
            sem.acquire();
        }
        for _ in 0..3 {
            sem.release();
        }
    }

    #[test]
    fn semaphore_blocks_fourth_until_release() {
        let sem = Arc::new(Semaphore::new(2));
        sem.acquire();
        sem.acquire();

        let (entered, handle) = spawn_acquirer(&sem);
        std::thread::sleep(SETTLE);
        assert!(
            !entered.load(Ordering::Acquire),
            "worker must be blocked while 2/2 permits are taken"
        );

        sem.release();
        handle.join().expect("worker thread");
        assert!(
            entered.load(Ordering::Acquire),
            "worker should have entered after release"
        );
        sem.release();
    }

    #[test]
    fn semaphore_current_max_reports_resize() {
        let sem = Semaphore::new(2);
        assert_eq!(sem.current_max(), 2);
        for ceiling in [5, 1] {
            sem.resize(ceiling);
            assert_eq!(sem.current_max(), ceiling);
        }
    }

    #[test]
    fn semaphore_resize_up_wakes_parked_acquirer() {
        let sem = Arc::new(Semaphore::new(1));
        sem.acquire();

        let (entered, handle) = spawn_acquirer(&sem);
        std::thread::sleep(SETTLE);
        assert!(
            !entered.load(Ordering::Acquire),
            "worker must be parked while 1/1 permits are taken"
        );

        sem.resize(2);
        handle.join().expect("worker thread");
        assert!(
            entered.load(Ordering::Acquire),
            "raising the ceiling should admit the parked worker"
        );
        sem.release();
    }

    #[test]
    fn semaphore_resize_down_blocks_new_acquire_until_count_drops() {
        let sem = Arc::new(Semaphore::new(2));
        sem.acquire();
        sem.acquire();
        sem.resize(1);
        // One permit still held against a ceiling of one.
        sem.release();

        let (entered, handle) = spawn_acquirer(&sem);
        std::thread::sleep(SETTLE);
        assert!(
            !entered.load(Ordering::Acquire),
            "count(1) >= new max(1): acquirer must block after shrink"
        );

        sem.release();
        handle.join().expect("worker thread");
        assert!(
            entered.load(Ordering::Acquire),
            "acquirer should proceed once count falls below the new ceiling"
        );
    }
}