use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::thread::JoinHandle;
use std::time::Duration;
/// Background sampler that tracks the peak resident set size (in MiB) of the
/// current process.
///
/// A dedicated thread polls [`get_rss_mb`] every `interval_ms` and folds each
/// reading into a running maximum. Call [`RssPeakSampler::stop`] to finish and
/// read the peak. Dropping the sampler without calling `stop` also shuts the
/// thread down, discarding the result, so the thread never outlives its owner.
pub struct RssPeakSampler {
    /// Highest reading seen so far; starts at the `seed_mb` passed to `start`.
    peak: Arc<AtomicUsize>,
    /// Set to ask the sampler thread to exit.
    stop: Arc<AtomicBool>,
    /// Sampler thread; `None` once it has been joined.
    handle: Option<JoinHandle<()>>,
}

impl RssPeakSampler {
    /// Spawns the sampler thread.
    ///
    /// `seed_mb` is a lower bound for the reported peak (e.g. a reading taken
    /// before sampling began). `interval_ms` is the delay between samples.
    ///
    /// # Panics
    ///
    /// Panics if the OS refuses to spawn the thread.
    pub fn start(seed_mb: usize, interval_ms: u64) -> Self {
        let peak = Arc::new(AtomicUsize::new(seed_mb));
        let stop = Arc::new(AtomicBool::new(false));
        let peak_c = Arc::clone(&peak);
        let stop_c = Arc::clone(&stop);
        let interval = Duration::from_millis(interval_ms);
        let handle = std::thread::Builder::new()
            .name("rivet-rss-peak".into())
            .spawn(move || {
                while !stop_c.load(Ordering::Acquire) {
                    peak_c.fetch_max(get_rss_mb(), Ordering::Relaxed);
                    // Park instead of sleeping so `shutdown` can wake us
                    // immediately. A spurious wakeup only causes an early
                    // re-check of the stop flag.
                    std::thread::park_timeout(interval);
                }
                // One final sample so activity right before shutdown counts.
                peak_c.fetch_max(get_rss_mb(), Ordering::Relaxed);
            })
            .expect("spawn rss peak sampler");
        Self {
            peak,
            stop,
            handle: Some(handle),
        }
    }

    /// Stops the sampler and returns the peak RSS in MiB.
    ///
    /// The result is the maximum of the seed, every sample taken, and one
    /// more reading taken after the thread has been joined.
    pub fn stop(mut self) -> usize {
        self.shutdown();
        let last = get_rss_mb();
        self.peak.load(Ordering::Relaxed).max(last)
    }

    /// Signals the sampler thread to exit, wakes it, and joins it.
    ///
    /// Safe to call more than once: after the first call `handle` is `None`.
    fn shutdown(&mut self) {
        self.stop.store(true, Ordering::Release);
        if let Some(handle) = self.handle.take() {
            handle.thread().unpark();
            // A panic in the sampler thread only loses samples; the peak
            // gathered so far is still usable, so the error is ignored.
            let _ = handle.join();
        }
    }
}

impl Drop for RssPeakSampler {
    /// Ensures the sampler thread is stopped and joined even when `stop` is
    /// never called.
    fn drop(&mut self) {
        self.shutdown();
    }
}
pub fn get_rss_mb() -> usize {
#[cfg(target_os = "macos")]
{
macos_rss_mb()
}
#[cfg(target_os = "linux")]
{
linux_rss_mb()
}
#[cfg(not(any(target_os = "macos", target_os = "linux")))]
{
0
}
}
/// macOS implementation of [`get_rss_mb`].
///
/// Queries `task_info` with the `MACH_TASK_BASIC_INFO` flavor for the current
/// task and converts `resident_size` from bytes to MiB, rounding down.
/// Returns 0 if the call does not return `KERN_SUCCESS`.
#[cfg(target_os = "macos")]
fn macos_rss_mb() -> usize {
    use std::mem;
    // SAFETY: `mach_task_basic_info_data_t` is a plain C struct of integers,
    // so an all-zero value is a valid out-buffer. `count` gives that buffer's
    // size in `natural_t` units (mirroring the C `MACH_TASK_BASIC_INFO_COUNT`
    // macro). Both pointers refer to live locals for the whole call.
    unsafe {
        let mut info: libc::mach_task_basic_info_data_t = mem::zeroed();
        // In/out: capacity of `info` on entry, number of units written on return.
        let mut count = (mem::size_of::<libc::mach_task_basic_info_data_t>()
            / mem::size_of::<libc::natural_t>())
            as libc::mach_msg_type_number_t;
        let kr = libc::task_info(
            mach2::traps::mach_task_self(),
            libc::MACH_TASK_BASIC_INFO,
            &mut info as *mut _ as libc::task_info_t,
            &mut count,
        );
        if kr == libc::KERN_SUCCESS {
            // `resident_size` is in bytes; integer division floors to whole MiB.
            (info.resident_size as usize) / (1024 * 1024)
        } else {
            // Best-effort probe: report 0 rather than failing the caller.
            0
        }
    }
}
/// Linux implementation of [`get_rss_mb`]: resident set size in MiB, rounded
/// down. Returns 0 if the value cannot be read or parsed.
///
/// This reads the `VmRSS` line of `/proc/self/status`, which is reported in
/// kB, rather than `/proc/self/statm`, which counts pages. That keeps the
/// result independent of the kernel page size, which is not always 4 KiB
/// (e.g. 16 KiB or 64 KiB on some aarch64 kernels).
#[cfg(target_os = "linux")]
fn linux_rss_mb() -> usize {
    std::fs::read_to_string("/proc/self/status")
        .ok()
        .and_then(|status| parse_vm_rss_mb(&status))
        .unwrap_or(0)
}

/// Extracts `VmRSS` (in kB) from the text of a `/proc/<pid>/status` file and
/// converts it to MiB, rounding down. Returns `None` if the line is missing or
/// malformed.
#[cfg(target_os = "linux")]
fn parse_vm_rss_mb(status: &str) -> Option<usize> {
    let rest = status
        .lines()
        .find_map(|line| line.strip_prefix("VmRSS:"))?;
    let kb: usize = rest.split_whitespace().next()?.parse().ok()?;
    Some(kb / 1024)
}
/// Returns `true` while the process RSS is at or below `threshold_mb` MiB.
///
/// A threshold of 0 disables the check entirely, and no RSS reading is taken.
/// When the limit is exceeded, a warning is logged and `false` is returned.
/// A reading of 0 (what `get_rss_mb` reports when it cannot measure) always
/// passes.
pub fn check_memory(threshold_mb: usize) -> bool {
    match threshold_mb {
        0 => true,
        limit => {
            let rss = get_rss_mb();
            let over = rss > limit;
            if over {
                log::warn!("RSS {}MB exceeds threshold {}MB", rss, threshold_mb);
            }
            !over
        }
    }
}
/// Counting semaphore with a ceiling that can be changed at runtime, built on
/// `Mutex` + `Condvar`.
///
/// Permits are not RAII guards: every [`Semaphore::acquire`] must be balanced
/// by exactly one [`Semaphore::release`]. A poisoned lock is recovered rather
/// than propagated, because the guarded state is just two counters.
#[derive(Debug)]
pub struct Semaphore {
    state: std::sync::Mutex<SemState>,
    cond: std::sync::Condvar,
}

/// Counters guarded by `Semaphore::state`.
#[derive(Debug)]
struct SemState {
    /// Permits currently held.
    count: usize,
    /// Ceiling on `count`; acquirers wait while `count >= max`.
    max: usize,
}

impl Semaphore {
    /// Creates a semaphore that admits up to `max` concurrent holders.
    ///
    /// With `max == 0`, every `acquire` blocks until `resize` raises the
    /// ceiling.
    pub fn new(max: usize) -> Self {
        Self {
            state: std::sync::Mutex::new(SemState { count: 0, max }),
            cond: std::sync::Condvar::new(),
        }
    }

    /// Locks the state, recovering the guard if a previous holder panicked.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, SemState> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Blocks until a permit is available, then takes it.
    pub fn acquire(&self) {
        let guard = self.lock_state();
        let mut st = self
            .cond
            .wait_while(guard, |s| s.count >= s.max)
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        st.count += 1;
    }

    /// Returns a permit and wakes one waiter.
    ///
    /// An unmatched release is a caller bug. It trips a `debug_assert!` in
    /// debug builds; otherwise it is ignored and the count stays at 0.
    pub fn release(&self) {
        let mut st = self.lock_state();
        debug_assert!(st.count > 0, "release without matching acquire");
        // Saturate rather than wrap. Without overflow checks (release builds),
        // `count -= 1` at zero would wrap to `usize::MAX` and block every later
        // `acquire` forever.
        st.count = st.count.saturating_sub(1);
        self.cond.notify_one();
    }

    /// Changes the ceiling to `new_max`.
    ///
    /// Raising it wakes every waiter so each can re-check. Lowering it never
    /// revokes permits already held; new acquirers wait until `count` drops
    /// below the new ceiling.
    pub fn resize(&self, new_max: usize) {
        let mut st = self.lock_state();
        let raised = new_max > st.max;
        st.max = new_max;
        if raised {
            self.cond.notify_all();
        }
    }

    /// Current ceiling (test-only introspection).
    #[cfg(test)]
    pub fn current_max(&self) -> usize {
        self.lock_state().max
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread::JoinHandle;
    use std::time::Duration;

    /// How long a spawned worker is given before asserting it is still parked.
    const PARK_GRACE: Duration = Duration::from_millis(50);

    /// Spawns a worker that acquires `sem`, raises the returned flag, then
    /// releases its permit.
    fn spawn_acquirer(sem: &Arc<Semaphore>) -> (Arc<AtomicBool>, JoinHandle<()>) {
        let entered = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&entered);
        let worker_sem = Arc::clone(sem);
        let handle = std::thread::spawn(move || {
            worker_sem.acquire();
            flag.store(true, Ordering::Release);
            worker_sem.release();
        });
        (entered, handle)
    }

    #[test]
    fn check_memory_zero_threshold_always_passes() {
        assert!(check_memory(0));
    }

    #[test]
    fn check_memory_huge_threshold_passes() {
        // 1 TiB expressed in MiB.
        assert!(check_memory(1 << 20));
    }

    #[test]
    fn get_rss_mb_does_not_panic() {
        let _ = get_rss_mb();
    }

    #[test]
    fn rss_peak_sampler_stop_returns_value() {
        let _peak = RssPeakSampler::start(0, 50).stop();
    }

    #[test]
    fn rss_peak_sampler_seed_is_lower_bound() {
        const HIGH_SEED: usize = 9999;
        let peak = RssPeakSampler::start(HIGH_SEED, 50).stop();
        assert!(peak >= HIGH_SEED);
    }

    #[test]
    fn semaphore_admits_up_to_max_without_blocking() {
        let sem = Semaphore::new(3);
        (0..3).for_each(|_| sem.acquire());
        (0..3).for_each(|_| sem.release());
    }

    #[test]
    fn semaphore_blocks_fourth_until_release() {
        let sem = Arc::new(Semaphore::new(2));
        sem.acquire();
        sem.acquire();
        let (entered, worker) = spawn_acquirer(&sem);
        std::thread::sleep(PARK_GRACE);
        assert!(
            !entered.load(Ordering::Acquire),
            "worker must be blocked while 2/2 permits are taken"
        );
        sem.release();
        worker.join().expect("worker thread");
        assert!(
            entered.load(Ordering::Acquire),
            "worker should have entered after release"
        );
        sem.release();
    }

    #[test]
    fn semaphore_current_max_reports_resize() {
        let sem = Semaphore::new(2);
        for (target, expected) in [(None, 2), (Some(5), 5), (Some(1), 1)].iter() {
            if let Some(new_max) = *target {
                sem.resize(new_max);
            }
            assert_eq!(sem.current_max(), *expected);
        }
    }

    #[test]
    fn semaphore_resize_up_wakes_parked_acquirer() {
        let sem = Arc::new(Semaphore::new(1));
        sem.acquire();
        let (entered, worker) = spawn_acquirer(&sem);
        std::thread::sleep(PARK_GRACE);
        assert!(
            !entered.load(Ordering::Acquire),
            "worker must be parked while 1/1 permits are taken"
        );
        sem.resize(2);
        worker.join().expect("worker thread");
        assert!(
            entered.load(Ordering::Acquire),
            "raising the ceiling should admit the parked worker"
        );
        sem.release();
    }

    #[test]
    fn semaphore_resize_down_blocks_new_acquire_until_count_drops() {
        let sem = Arc::new(Semaphore::new(2));
        sem.acquire();
        sem.acquire();
        sem.resize(1);
        sem.release();
        let (entered, worker) = spawn_acquirer(&sem);
        std::thread::sleep(PARK_GRACE);
        assert!(
            !entered.load(Ordering::Acquire),
            "count(1) >= new max(1): acquirer must block after shrink"
        );
        sem.release();
        worker.join().expect("worker thread");
        assert!(
            entered.load(Ordering::Acquire),
            "acquirer should proceed once count falls below the new ceiling"
        );
    }
}