use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
/// Returns the current resident set size (RSS) of this process, in bytes.
///
/// The backend is selected at compile time:
/// - Linux: `linux_rss` (reads `/proc/self/statm`).
/// - macOS: `macos_rss` (queries Mach `task_info`).
/// - Any other target: always returns an error.
///
/// # Errors
///
/// Propagates whatever error the platform backend returns. On unsupported
/// platforms it always returns `crate::BenchError::MemoryUnavailable`.
pub fn current_rss_bytes() -> crate::Result<u64> {
// Exactly one of the three cfg'd blocks below survives compilation on any
// given target, and that block becomes the function's tail expression.
#[cfg(target_os = "linux")]
{
linux_rss()
}
#[cfg(target_os = "macos")]
{
macos_rss()
}
#[cfg(not(any(target_os = "linux", target_os = "macos")))]
{
Err(crate::BenchError::MemoryUnavailable(
"memory sampling not supported on this platform".to_string(),
))
}
}
/// Reads this process's resident set size on Linux from `/proc/self/statm`.
///
/// `statm` reports sizes in pages. The second whitespace-separated field
/// (`resident`) is multiplied by the page size from `sysconf(_SC_PAGESIZE)`
/// to obtain bytes.
///
/// # Errors
///
/// Returns `crate::BenchError::MemoryUnavailable` if:
/// - the file cannot be read;
/// - the resident field is missing or is not an unsigned integer;
/// - `sysconf` does not report a positive page size;
/// - the byte count would overflow `u64`.
#[cfg(target_os = "linux")]
fn linux_rss() -> crate::Result<u64> {
    use std::fs;

    let statm = fs::read_to_string("/proc/self/statm").map_err(|e| {
        crate::BenchError::MemoryUnavailable(format!("failed to read /proc/self/statm: {e}"))
    })?;

    // Field layout (proc(5)): size resident shared text lib data dt -- index 1 is `resident`.
    let rss_pages: u64 =
        statm.split_whitespace().nth(1).and_then(|s| s.parse().ok()).ok_or_else(|| {
            crate::BenchError::MemoryUnavailable(
                "failed to parse RSS from /proc/self/statm".to_string(),
            )
        })?;

    // SAFETY: `sysconf` takes no pointer arguments and only queries a configuration
    // value, so the call has no memory-safety preconditions.
    let raw_page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
    // `sysconf` signals failure with -1. A bare `as u64` cast would turn that into
    // u64::MAX and make the multiplication below overflow (panic in debug builds,
    // silent wrap-around in release builds).
    let page_size = u64::try_from(raw_page_size).ok().filter(|&size| size > 0).ok_or_else(|| {
        crate::BenchError::MemoryUnavailable(format!(
            "sysconf(_SC_PAGESIZE) returned invalid page size {raw_page_size}"
        ))
    })?;

    rss_pages.checked_mul(page_size).ok_or_else(|| {
        crate::BenchError::MemoryUnavailable(format!(
            "RSS overflows u64: {rss_pages} pages x {page_size} bytes"
        ))
    })
}
/// Reads this process's resident memory size on macOS by calling Mach
/// `task_info` with the `MACH_TASK_BASIC_INFO` flavor on the current task.
///
/// Returns the `resident_size` field of the filled-in
/// `mach_task_basic_info` structure.
///
/// # Errors
///
/// Returns `crate::BenchError::MemoryUnavailable`, including the raw
/// `kern_return_t`, when `task_info` does not return `KERN_SUCCESS`.
#[cfg(target_os = "macos")]
fn macos_rss() -> crate::Result<u64> {
use std::mem;
// SAFETY: `mach_task_basic_info_data_t` is (per the Mach headers) a C struct of
// integer fields, so the all-zero value produced by `mem::zeroed` is valid.
// `count` starts at MACH_TASK_BASIC_INFO_COUNT, which tells `task_info` the
// capacity of `info`. Both pointers handed to `task_info` refer to live,
// exclusively borrowed locals that outlive the call.
unsafe {
let mut info: libc::mach_task_basic_info_data_t = mem::zeroed();
let mut count = libc::MACH_TASK_BASIC_INFO_COUNT;
// NOTE(review): libc marks `mach_task_self` as deprecated (presumably in favour
// of the `mach2` crate -- confirm). The lint is silenced instead of adding a dependency.
#[allow(deprecated)]
let task_self = libc::mach_task_self();
// `task_info` takes its output buffer as a `task_info_t` (pointer to integer
// words), hence the pointer cast. `count` is an in/out parameter.
let ret = libc::task_info(
task_self,
libc::MACH_TASK_BASIC_INFO,
(&raw mut info) as libc::task_info_t,
&raw mut count,
);
if ret != libc::KERN_SUCCESS {
return Err(crate::BenchError::MemoryUnavailable(format!(
"task_info(MACH_TASK_BASIC_INFO) failed with kern_return_t={ret}"
)));
}
Ok(info.resident_size)
}
}
pub fn spawn_memory_sampler(interval: Duration) -> (JoinHandle<()>, Arc<Mutex<Vec<u64>>>) {
let samples = Arc::new(Mutex::new(Vec::new()));
let samples_clone = samples.clone();
let handle = tokio::spawn(async move {
loop {
if let Ok(rss) = current_rss_bytes() {
samples_clone.lock().await.push(rss);
}
tokio::time::sleep(interval).await;
}
});
(handle, samples)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn test_current_rss_bytes_returns_positive() {
        let rss = current_rss_bytes().expect("RSS sampling should succeed on this platform");
        assert!(rss > 0, "RSS must be positive, got {rss}");
    }

    #[test]
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn test_current_rss_bytes_reasonable_range() {
        let rss = current_rss_bytes().expect("RSS sampling should succeed");
        assert!(rss >= 1_000_000, "RSS should be at least 1 MB, got {rss} bytes");
        assert!(rss < 64 * 1024 * 1024 * 1024, "RSS should be less than 64 GB, got {rss} bytes");
    }

    #[test]
    #[cfg(not(any(target_os = "linux", target_os = "macos")))]
    fn test_current_rss_bytes_unsupported_platform() {
        let result = current_rss_bytes();
        assert!(result.is_err());
    }

    #[tokio::test]
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    async fn test_spawn_memory_sampler_collects_samples() {
        let interval = Duration::from_millis(10);
        let (handle, samples) = spawn_memory_sampler(interval);
        tokio::time::sleep(Duration::from_millis(60)).await;
        handle.abort();
        // `abort` only requests cancellation; wait for the task to be torn down so
        // no push can race with the reads below on a multi-threaded runtime.
        let _ = handle.await;
        let data = samples.lock().await;
        assert!(!data.is_empty(), "sampler should have collected at least one sample");
        for &sample in data.iter() {
            assert!(sample > 0, "each RSS sample must be positive");
        }
    }

    #[tokio::test]
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    async fn test_spawn_memory_sampler_abort_stops_collection() {
        let interval = Duration::from_millis(10);
        let (handle, samples) = spawn_memory_sampler(interval);
        tokio::time::sleep(Duration::from_millis(30)).await;
        handle.abort();
        // Awaiting the handle guarantees the task has actually stopped before the
        // baseline count is taken. The sampler never finishes on its own, so the
        // join must report cancellation.
        let joined = handle.await;
        assert!(
            joined.is_err_and(|e| e.is_cancelled()),
            "sampler task should report cancellation after abort"
        );
        let count_after_abort = samples.lock().await.len();
        // Without this, the test would pass trivially if the sampler never ran.
        assert!(count_after_abort > 0, "sampler should have collected samples before abort");
        tokio::time::sleep(Duration::from_millis(50)).await;
        let count_later = samples.lock().await.len();
        assert_eq!(count_after_abort, count_later, "no new samples should appear after abort");
    }
}