use procfs::process::Process;
use std::collections::HashMap;
use std::io::{Error, ErrorKind};
use std::time::{Duration, Instant};
/// One reading of a process's cumulative CPU counters.
#[derive(Clone, Debug)]
struct CpuTimes {
    /// Cumulative user-mode CPU time in clock ticks (`utime` from the process stat).
    user: u64,
    /// Cumulative kernel-mode CPU time in clock ticks (`stime` from the process stat).
    system: u64,
    /// Monotonic instant at which the counters were read.
    timestamp: Instant,
}
/// Computes per-process CPU usage from successive samples of the process stat counters.
///
/// The sampler keeps the most recent reading for every PID it has been asked about;
/// call `cleanup_stale_entries` to drop entries for PIDs that are no longer of interest.
#[derive(Debug)]
pub struct CpuSampler {
    /// Last counter snapshot per PID, used as the baseline for the next delta.
    previous_times: HashMap<usize, CpuTimes>,
    /// Clock-tick rate reported by `sysconf(_SC_CLK_TCK)`, used to convert ticks to time.
    clock_ticks_per_sec: u64,
}
impl Default for CpuSampler {
fn default() -> Self {
Self::new()
}
}
impl CpuSampler {
    /// Tick rate assumed when `sysconf(_SC_CLK_TCK)` reports an error.
    /// 100 is the conventional USER_HZ value on Linux.
    const FALLBACK_CLOCK_TICKS: u64 = 100;

    /// Creates a sampler with no recorded history.
    ///
    /// The kernel clock-tick rate is queried once here and reused to convert the
    /// tick counts read from procfs into seconds.
    pub fn new() -> Self {
        Self {
            previous_times: HashMap::new(),
            clock_ticks_per_sec: Self::query_clock_ticks(),
        }
    }

    /// Returns the average CPU usage of `pid` over its entire lifetime, as a
    /// percentage of one CPU clamped to `100.0`.
    ///
    /// Computed as `(utime + stime) / (uptime_ticks - starttime)`. `starttime` is
    /// the tick (since boot) at which the process started, so subtracting it from
    /// the current system uptime gives the process's age. Returns `0.0` if the
    /// age rounds to zero or less (e.g. a process that has only just started).
    ///
    /// # Errors
    ///
    /// * `ErrorKind::NotFound` if `pid` is out of range or the process cannot be opened.
    /// * An error if the process stat or `/proc/uptime` cannot be read or parsed.
    pub fn get_cpu_usage_static(pid: usize) -> Result<f32, std::io::Error> {
        let process = Process::new(Self::checked_pid(pid)?).map_err(|e| {
            Error::new(ErrorKind::NotFound, format!("Process not found: {e}"))
        })?;
        let stat = process
            .stat()
            .map_err(|e| Error::other(format!("Failed to read process stat: {e}")))?;
        let total_ticks = (stat.utime + stat.stime) as f64;
        let ticks_per_sec = Self::query_clock_ticks() as f64;
        let uptime_ticks = Self::system_uptime_secs()? * ticks_per_sec;
        let elapsed_ticks = uptime_ticks - stat.starttime as f64;
        if elapsed_ticks > 0.0 {
            let cpu_usage = (total_ticks / elapsed_ticks) * 100.0;
            Ok(cpu_usage.min(100.0) as f32)
        } else {
            Ok(0.0)
        }
    }

    /// Returns the CPU usage of `pid` since the previous call for the same PID,
    /// as a percentage of one CPU.
    ///
    /// The value is not clamped; a process using more than one CPU can exceed `100.0`.
    ///
    /// Returns `None` when:
    /// * the process cannot be read;
    /// * this is the first sample for `pid` (the reading becomes the baseline);
    /// * less than 10 ms have passed since the baseline (the baseline is kept);
    /// * the cumulative counters went backwards, which means the PID was reused
    ///   by a different process (the baseline is reset to the new process).
    pub fn get_cpu_usage(&mut self, pid: usize) -> Option<f32> {
        let current = Self::read_process_times(pid).ok()?;
        let Some(previous) = self.previous_times.get(&pid) else {
            self.previous_times.insert(pid, current);
            return None;
        };
        let time_delta = current.timestamp.duration_since(previous.timestamp);
        // Too short an interval gives a very noisy ratio; wait for more elapsed time.
        if time_delta < Duration::from_millis(10) {
            return None;
        }
        let previous_total = previous.user + previous.system;
        let current_total = current.user + current.system;
        // Cumulative CPU time never decreases for a live process, so a drop means the
        // PID now belongs to a new process. Start a fresh baseline instead of
        // underflowing (a panic in debug builds, a huge bogus value in release).
        let Some(cpu_delta) = current_total.checked_sub(previous_total) else {
            self.previous_times.insert(pid, current);
            return None;
        };
        let time_delta_ticks = time_delta.as_secs_f64() * self.clock_ticks_per_sec as f64;
        let usage = (cpu_delta as f64 / time_delta_ticks) * 100.0;
        self.previous_times.insert(pid, current);
        Some(usage as f32)
    }

    /// Reads the cumulative user/system CPU ticks of `pid` and stamps them with `Instant::now()`.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::NotFound` if `pid` is out of range or the process cannot be opened.
    /// * `ErrorKind::InvalidData` if its stat cannot be read.
    fn read_process_times(pid: usize) -> Result<CpuTimes, std::io::Error> {
        let process = Process::new(Self::checked_pid(pid)?).map_err(|e| {
            Error::new(
                ErrorKind::NotFound,
                format!("Failed to access process {pid}: {e}"),
            )
        })?;
        let stat = process.stat().map_err(|e| {
            Error::new(
                ErrorKind::InvalidData,
                format!("Failed to read process stats: {e}"),
            )
        })?;
        Ok(CpuTimes {
            user: stat.utime,
            system: stat.stime,
            timestamp: Instant::now(),
        })
    }

    /// Drops the stored baselines for every PID not listed in `active_pids`.
    pub fn cleanup_stale_entries(&mut self, active_pids: &[usize]) {
        // Build the set once so pruning is O(n + m) rather than O(n * m).
        let active: std::collections::HashSet<usize> = active_pids.iter().copied().collect();
        self.previous_times.retain(|pid, _| active.contains(pid));
    }

    /// Queries the kernel clock-tick rate, falling back to
    /// `FALLBACK_CLOCK_TICKS` if `sysconf` reports an error.
    fn query_clock_ticks() -> u64 {
        // SAFETY: `sysconf` takes no pointers and has no preconditions; it only
        // reads a system configuration value.
        let ticks = unsafe { libc::sysconf(libc::_SC_CLK_TCK) };
        // `sysconf` returns -1 on error. A plain `as u64` cast would turn that into
        // u64::MAX and make every computed usage effectively zero.
        u64::try_from(ticks)
            .ok()
            .filter(|&t| t > 0)
            .unwrap_or(Self::FALLBACK_CLOCK_TICKS)
    }

    /// Converts a `usize` PID to the `i32` that `Process::new` expects.
    ///
    /// An unchecked `as i32` cast silently wraps out-of-range values onto an
    /// unrelated PID (e.g. `2^32 + 1` would become PID 1). Such a PID cannot name
    /// a live process, so it is reported as `ErrorKind::NotFound`.
    fn checked_pid(pid: usize) -> Result<i32, Error> {
        i32::try_from(pid).map_err(|_| {
            Error::new(
                ErrorKind::NotFound,
                format!("Process {pid} is outside the valid PID range"),
            )
        })
    }

    /// Reads the system uptime in seconds (the first field of `/proc/uptime`).
    fn system_uptime_secs() -> Result<f64, Error> {
        let contents = std::fs::read_to_string("/proc/uptime")
            .map_err(|e| Error::new(e.kind(), format!("Failed to read /proc/uptime: {e}")))?;
        contents
            .split_whitespace()
            .next()
            .and_then(|field| field.parse::<f64>().ok())
            .ok_or_else(|| Error::new(ErrorKind::InvalidData, "Malformed /proc/uptime"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::process::{Child, Command};

    /// A CPU-bound child process should report non-zero usage after a delay.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_cpu_measurement_accuracy() {
        let mut sampler = CpuSampler::new();
        // Use only POSIX shell builtins. `let` is a bash extension: under dash
        // (`/bin/sh` on Debian/Ubuntu) each iteration fails with "let: not found"
        // instead of doing arithmetic work.
        let child = Command::new("sh")
            .arg("-c")
            .arg("i=0; while [ \"$i\" -lt 10000000 ]; do i=$((i + 1)); done")
            .spawn()
            .expect("Failed to spawn test process");
        let pid = child.id() as usize;
        assert!(sampler.get_cpu_usage(pid).is_none());
        std::thread::sleep(Duration::from_millis(500));
        let mut usage = 0.0;
        for _ in 0..5 {
            if let Some(u) = sampler.get_cpu_usage(pid) {
                usage = u;
                if usage > 0.0 {
                    break;
                }
            }
            std::thread::sleep(Duration::from_millis(100));
        }
        // Reap the child before asserting so a failure does not leave it running.
        kill_child(child);
        assert!(usage > 0.0, "CPU usage should be greater than 0: {}", usage);
    }

    /// Doing work in this process should advance its user or system CPU counters.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_read_process_times() {
        let pid = std::process::id() as usize;
        let times = CpuSampler::read_process_times(pid).expect("Failed to read process times");
        println!(
            "User CPU time: {}, System CPU time: {}",
            times.user, times.system
        );
        // CPU time is accounted in whole clock ticks (commonly 10 ms), so a fixed
        // amount of work may not register. Keep working until a counter moves,
        // bounded by a deadline so the test cannot hang.
        let deadline = Instant::now() + Duration::from_secs(5);
        let times_after = loop {
            for _ in 0..1000000 {
                std::hint::black_box(std::time::SystemTime::now());
            }
            let sample =
                CpuSampler::read_process_times(pid).expect("Failed to read process times");
            let advanced = sample.user > times.user || sample.system > times.system;
            if advanced || Instant::now() >= deadline {
                break sample;
            }
        };
        assert!(
            times_after.user > times.user || times_after.system > times.system,
            "Either user or system CPU time should increase after doing work"
        );
    }

    /// Only PIDs listed as active survive `cleanup_stale_entries`.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_cleanup_stale_entries() {
        let mut sampler = CpuSampler::new();
        let child1 = Command::new("sh")
            .arg("-c")
            .arg("sleep 2")
            .spawn()
            .expect("Failed to spawn test process");
        let child2 = Command::new("sh")
            .arg("-c")
            .arg("sleep 2")
            .spawn()
            .expect("Failed to spawn test process");
        let pid1 = child1.id() as usize;
        let pid2 = child2.id() as usize;
        sampler.get_cpu_usage(pid1);
        sampler.get_cpu_usage(pid2);
        assert!(sampler.previous_times.contains_key(&pid1));
        assert!(sampler.previous_times.contains_key(&pid2));
        sampler.cleanup_stale_entries(&[pid1]);
        assert!(sampler.previous_times.contains_key(&pid1));
        assert!(!sampler.previous_times.contains_key(&pid2));
        kill_child(child1);
        kill_child(child2);
    }

    /// Best-effort termination and reaping of a spawned child.
    fn kill_child(mut child: Child) {
        child.kill().ok();
        child.wait().ok();
    }

    #[test]
    fn test_cpu_sampler_new() {
        let sampler = CpuSampler::new();
        assert_eq!(sampler.previous_times.len(), 0);
        assert!(sampler.clock_ticks_per_sec > 0);
    }

    #[test]
    fn test_cpu_sampler_default() {
        let sampler = CpuSampler::default();
        assert_eq!(sampler.previous_times.len(), 0);
        assert!(sampler.clock_ticks_per_sec > 0);
    }

    /// The lifetime average is clamped to one CPU's worth of usage.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_get_cpu_usage_static() {
        let pid = std::process::id() as usize;
        let usage = CpuSampler::get_cpu_usage_static(pid).expect("Failed to read CPU usage");
        assert!(
            (0.0..=100.0).contains(&usage),
            "usage out of range: {usage}"
        );
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_get_cpu_usage_static_invalid_pid() {
        let result = CpuSampler::get_cpu_usage_static(999999);
        assert!(result.is_err());
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_read_process_times_invalid_pid() {
        let result = CpuSampler::read_process_times(999999);
        assert!(result.is_err());
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_cpu_usage_first_measurement_returns_none() {
        let mut sampler = CpuSampler::new();
        let pid = std::process::id() as usize;
        let result = sampler.get_cpu_usage(pid);
        assert!(result.is_none());
        assert!(sampler.previous_times.contains_key(&pid));
    }

    /// Two samples far less than 10 ms apart yield no measurement.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_cpu_usage_quick_successive_calls() {
        let mut sampler = CpuSampler::new();
        let pid = std::process::id() as usize;
        sampler.get_cpu_usage(pid);
        let result = sampler.get_cpu_usage(pid);
        assert!(result.is_none());
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_cpu_usage_with_delay() {
        let mut sampler = CpuSampler::new();
        let pid = std::process::id() as usize;
        sampler.get_cpu_usage(pid);
        std::thread::sleep(Duration::from_millis(50));
        for _ in 0..100000 {
            std::hint::black_box(std::time::SystemTime::now());
        }
        let usage = sampler
            .get_cpu_usage(pid)
            .expect("a sample 50 ms after the baseline should be measurable");
        assert!(usage >= 0.0);
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_cleanup_stale_entries_empty_active_list() {
        let mut sampler = CpuSampler::new();
        let pid = std::process::id() as usize;
        sampler.get_cpu_usage(pid);
        assert!(sampler.previous_times.contains_key(&pid));
        sampler.cleanup_stale_entries(&[]);
        assert!(!sampler.previous_times.contains_key(&pid));
        assert_eq!(sampler.previous_times.len(), 0);
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_multiple_pids_tracking() {
        let mut sampler = CpuSampler::new();
        let child1 = Command::new("sleep")
            .arg("1")
            .spawn()
            .expect("Failed to spawn test process");
        let child2 = Command::new("sleep")
            .arg("1")
            .spawn()
            .expect("Failed to spawn test process");
        let pid1 = child1.id() as usize;
        let pid2 = child2.id() as usize;
        sampler.get_cpu_usage(pid1);
        sampler.get_cpu_usage(pid2);
        assert!(sampler.previous_times.contains_key(&pid1));
        assert!(sampler.previous_times.contains_key(&pid2));
        assert_eq!(sampler.previous_times.len(), 2);
        kill_child(child1);
        kill_child(child2);
    }

    #[test]
    fn test_cpu_times_clone() {
        let times = CpuTimes {
            user: 100,
            system: 200,
            timestamp: Instant::now(),
        };
        let cloned = times.clone();
        assert_eq!(times.user, cloned.user);
        assert_eq!(times.system, cloned.system);
    }

    #[test]
    fn test_cpu_times_debug() {
        let times = CpuTimes {
            user: 100,
            system: 200,
            timestamp: Instant::now(),
        };
        let debug_str = format!("{:?}", times);
        assert!(debug_str.contains("CpuTimes"));
        assert!(debug_str.contains("user"));
        assert!(debug_str.contains("system"));
    }

    /// A reaped child no longer has a procfs entry, so sampling yields `None`.
    #[test]
    #[cfg(target_os = "linux")]
    fn test_sampler_with_terminated_process() {
        let mut sampler = CpuSampler::new();
        let mut child = Command::new("true")
            .spawn()
            .expect("Failed to spawn test process");
        let pid = child.id() as usize;
        let _ = child.wait();
        let result = sampler.get_cpu_usage(pid);
        assert!(result.is_none());
    }
}