use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use anyhow::Result;
use crate::KTSTR_STALL_POLL_MS_ENV;
/// Polling interval, in milliseconds, that `resolve_poll_interval` falls back
/// to when the `KTSTR_STALL_POLL_MS_ENV` variable is unset, empty, unparsable
/// or zero.
pub const DEFAULT_POLL_INTERVAL_MS: u64 = 500;
/// Maximum number of samples `process_iteration` keeps per pid, and the
/// minimum number of samples `stall_predicate` requires before it can
/// report a stall.
pub const STALL_WINDOW: usize = 4;
/// One observation of a task's scheduler counters, built by
/// `read_sched_sample` from the contents of `/proc/<pid>/sched`.
///
/// `PartialEq`/`Eq` are derived, so comparing two samples also compares
/// `captured_at`; the stall logic in this file compares the two counter
/// fields individually instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SchedSample {
/// Value of the `nr_switches` line.
pub nr_switches: u64,
/// Value of the `se.sum_exec_runtime` line (milliseconds in the file),
/// scaled to nanoseconds by `parse_sched_file`.
pub sum_exec_runtime_ns: u64,
/// Monotonic timestamp taken when the sample was constructed. Skipped by
/// serde; a deserialized sample gets `Instant::now()` instead.
#[serde(skip, default = "Instant::now")]
pub captured_at: Instant,
}
/// Snapshot of `/proc` state gathered by `capture_diagnostic` for a pid that
/// was just flagged as stalled.
///
/// Every `String` field read through `read_proc_field` holds either the file
/// contents with trailing newlines removed or a `[unreadable: <error>]`
/// placeholder when the read failed.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct StallDiagnostic {
/// Contents of `/proc/<pid>/wchan`, or the unreadable placeholder.
pub wchan: String,
/// Contents of `/proc/<pid>/syscall`, or the unreadable placeholder.
pub syscall: String,
/// First whitespace-separated token after `State:` in `status_full`, or
/// `"?"` when no such token exists.
pub state: String,
/// Raw contents of `/proc/<pid>/stack`; `None` when the file could not be
/// read.
pub stack: Option<String>,
/// Contents of `/proc/<pid>/status`, or the unreadable placeholder.
pub status_full: String,
/// Contents of `/proc/<pid>/cgroup`, or the unreadable placeholder.
pub cgroup: String,
/// Contents of `/proc/loadavg` with trailing newlines removed, or the
/// unreadable placeholder.
pub host_loadavg: String,
}
/// Record pushed by `poll_loop` each time `process_iteration` signals a stall
/// for one of the monitored pids.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct StallReport {
/// Pid whose counters stopped changing.
pub pid: libc::pid_t,
/// Contents of `/proc/<pid>/comm` with trailing newlines removed, or
/// `[unreadable: <reason>]` when the read failed.
pub comm: String,
/// Copy of the pid's sample window at the moment the stall fired, oldest
/// sample first.
pub samples: Vec<SchedSample>,
/// Monotonic timestamp taken when the report was assembled (after the
/// diagnostic capture). Skipped by serde; a deserialized report gets
/// `Instant::now()` instead.
#[serde(skip, default = "Instant::now")]
pub captured_at: Instant,
/// `/proc` snapshot taken right after the stall was detected.
pub diagnostic: StallDiagnostic,
}
/// State shared between the owner of a `StallMonitorHandle` and the polling
/// thread started by `spawn_monitor`.
pub(crate) struct StallMonitor {
/// Set to `true` to ask the polling thread to exit; the thread checks it
/// at the top of each poll round and between sleep slices.
shutdown: Arc<AtomicBool>,
/// Reports appended by the polling thread and removed by
/// `StallMonitorHandle::drain`.
reports: Arc<Mutex<Vec<StallReport>>>,
}
/// Owner-side handle for a running stall monitor.
///
/// Dropping the handle sets the shutdown flag and joins the polling thread,
/// which is why the type is `#[must_use]`: an unbound handle would stop the
/// monitor immediately.
#[must_use = "StallMonitorHandle stops polling on Drop; bind it to a local for the scenario lifetime"]
pub(crate) struct StallMonitorHandle {
/// Shutdown flag and report buffer shared with the polling thread.
monitor: StallMonitor,
/// Join handle of the polling thread; `Drop` takes it out to join, leaving
/// `None` behind.
thread: Option<JoinHandle<()>>,
}
impl StallMonitorHandle {
    /// Hands back every report collected since the previous call and leaves
    /// the shared buffer empty.
    ///
    /// # Panics
    /// Panics if the reports mutex has been poisoned.
    pub(crate) fn drain(&self) -> Vec<StallReport> {
        let shared = &self.monitor.reports;
        let mut pending = shared
            .lock()
            .expect("stall-monitor reports mutex poisoned");
        let collected = std::mem::take(&mut *pending);
        collected
    }
}
impl Drop for StallMonitorHandle {
    /// Requests shutdown of the polling thread and waits for it to finish.
    ///
    /// A panic inside the polling thread is logged as a warning rather than
    /// propagated to the thread running this destructor.
    fn drop(&mut self) {
        self.monitor.shutdown.store(true, Ordering::SeqCst);

        // `take()` leaves `None` behind, so there is nothing to join twice.
        let Some(worker) = self.thread.take() else {
            return;
        };
        if let Err(e) = worker.join() {
            tracing::warn!(?e, "stall-monitor polling thread panicked");
        }
    }
}
/// Starts a background thread named `ktstr-stall-mon` that polls the given
/// pids for scheduler stalls.
///
/// The poll interval comes from `resolve_poll_interval`. The returned handle
/// shares the shutdown flag and report buffer with the thread; dropping it
/// stops and joins the thread.
///
/// # Errors
/// Returns an error when the OS refuses to spawn the thread.
pub(crate) fn spawn_monitor(pids: &[libc::pid_t]) -> Result<StallMonitorHandle> {
    let monitor = StallMonitor {
        shutdown: Arc::new(AtomicBool::new(false)),
        reports: Arc::new(Mutex::new(Vec::new())),
    };

    let worker = {
        // The thread gets its own clones; the originals stay in the handle.
        let stop_flag = Arc::clone(&monitor.shutdown);
        let sink = Arc::clone(&monitor.reports);
        let watched = pids.to_vec();
        let interval = resolve_poll_interval();
        thread::Builder::new()
            .name(String::from("ktstr-stall-mon"))
            .spawn(move || poll_loop(watched, interval, stop_flag, sink))
            .map_err(|e| anyhow::anyhow!("failed to spawn stall-monitor thread: {e}"))?
    };

    Ok(StallMonitorHandle {
        monitor,
        thread: Some(worker),
    })
}
/// Determines how long the polling thread waits between rounds.
///
/// Reads the environment variable named by `KTSTR_STALL_POLL_MS_ENV`. A
/// non-empty value that, once trimmed, parses as a `u64` greater than zero
/// is used as the interval in milliseconds; anything else yields
/// `DEFAULT_POLL_INTERVAL_MS`.
fn resolve_poll_interval() -> Duration {
    let requested = match std::env::var(KTSTR_STALL_POLL_MS_ENV) {
        Ok(raw) if !raw.is_empty() => raw.trim().parse::<u64>().ok(),
        _ => None,
    };
    let millis = match requested {
        // Zero is treated as "not configured" and replaced by the default.
        Some(n) if n > 0 => n,
        _ => DEFAULT_POLL_INTERVAL_MS,
    };
    Duration::from_millis(millis)
}
/// Body of the polling thread: samples every pid once per `interval` until
/// `shutdown` becomes `true`, pushing a `StallReport` into `reports` whenever
/// `process_iteration` signals a stall.
///
/// Per-pid state is a sample window plus an "armed" flag, both owned by this
/// function. When a pid's sched file cannot be read or parsed, its window is
/// cleared and the pid is skipped for that round (the armed flag is left as
/// it was).
///
/// The wait between rounds is split into slices of at most 50 ms so a
/// shutdown request is noticed without waiting for the full interval.
///
/// # Panics
/// Panics if the reports mutex has been poisoned.
fn poll_loop(
    pids: Vec<libc::pid_t>,
    interval: Duration,
    shutdown: Arc<AtomicBool>,
    reports: Arc<Mutex<Vec<StallReport>>>,
) {
    let mut tracked: Vec<(libc::pid_t, VecDeque<SchedSample>, bool)> =
        Vec::with_capacity(pids.len());
    for &pid in &pids {
        tracked.push((pid, VecDeque::with_capacity(STALL_WINDOW), true));
    }

    // Longest single sleep; never longer than the interval itself.
    let nap = interval.min(Duration::from_millis(50));

    while !shutdown.load(Ordering::SeqCst) {
        for (pid, window, armed) in &mut tracked {
            let pid = *pid;

            let Some(sample) = read_sched_sample(pid) else {
                window.clear();
                continue;
            };
            if !process_iteration(sample, window, armed) {
                continue;
            }

            // Stall detected: snapshot the window, then gather context.
            let samples = window.iter().copied().collect::<Vec<SchedSample>>();
            let comm = match read_comm(pid) {
                Ok(name) => name,
                Err(reason) => format!("[unreadable: {reason}]"),
            };
            let diagnostic = capture_diagnostic(pid);
            let report = StallReport {
                pid,
                comm,
                samples,
                captured_at: Instant::now(),
                diagnostic,
            };

            // Hold the lock only for the push itself.
            reports
                .lock()
                .expect("stall-monitor reports mutex poisoned")
                .push(report);
        }

        let mut left = interval;
        while !left.is_zero() && !shutdown.load(Ordering::SeqCst) {
            let step = left.min(nap);
            thread::sleep(step);
            left = left.saturating_sub(step);
        }
    }
}
/// Feeds one new sample into a pid's window and decides whether a stall
/// report should be emitted now.
///
/// Steps, in order:
/// 1. Append `sample` and evict from the front until at most `STALL_WINDOW`
///    samples remain.
/// 2. If the two newest samples differ in either counter, set `*armed` to
///    `true` (the task made progress, so a later stall may be reported).
/// 3. If armed and `stall_predicate` holds for the window, clear `*armed`
///    and return `true`.
///
/// Returns `false` in every other case, so an uninterrupted stall is
/// reported only once.
pub(crate) fn process_iteration(
    sample: SchedSample,
    window: &mut VecDeque<SchedSample>,
    armed: &mut bool,
) -> bool {
    window.push_back(sample);
    let overflow = window.len().saturating_sub(STALL_WINDOW);
    for _ in 0..overflow {
        window.pop_front();
    }

    // Compare the newest sample with its predecessor, when there is one.
    let mut newest_first = window.iter().rev();
    if let (Some(last), Some(prev)) = (newest_first.next(), newest_first.next()) {
        let unchanged = last.nr_switches == prev.nr_switches
            && last.sum_exec_runtime_ns == prev.sum_exec_runtime_ns;
        if !unchanged {
            *armed = true;
        }
    }

    if !*armed {
        return false;
    }
    let fired = stall_predicate(window.make_contiguous());
    if fired {
        *armed = false;
    }
    fired
}
/// Reports whether `samples` shows a task that made no scheduling progress.
///
/// Returns `true` only when the slice holds at least `STALL_WINDOW` samples
/// and every consecutive pair has identical `nr_switches` and
/// `sum_exec_runtime_ns`. Timestamps are not considered.
pub fn stall_predicate(samples: &[SchedSample]) -> bool {
    let enough_samples = samples.len() >= STALL_WINDOW;
    enough_samples
        && samples
            .iter()
            .zip(samples.iter().skip(1))
            .all(|(earlier, later)| {
                earlier.nr_switches == later.nr_switches
                    && earlier.sum_exec_runtime_ns == later.sum_exec_runtime_ns
            })
}
/// Extracts `nr_switches` and `se.sum_exec_runtime` from the text of a
/// `/proc/<pid>/sched` file.
///
/// Each line is split at its first `:`; the trimmed left part is the key and
/// the trimmed right part the value. Lines without a `:` and lines with any
/// other key are ignored. Scanning stops as soon as both fields are parsed.
///
/// `se.sum_exec_runtime` is written in milliseconds with a fractional part
/// and is returned in nanoseconds. A plain `<digits>[.<digits>]` value is
/// converted with integer arithmetic so the result is exact; fractional
/// digits beyond the sixth (sub-nanosecond) are dropped. The previous
/// `(ms * 1_000_000.0) as u64` conversion truncated after float rounding and
/// could come out one nanosecond low (e.g. `1.000001` became `1_000_000`).
///
/// # Returns
/// `Some((nr_switches, sum_exec_runtime_ns))` when both fields were found and
/// parsed, `None` otherwise.
pub fn parse_sched_file(content: &str) -> Option<(u64, u64)> {
    /// Converts a decimal millisecond string to whole nanoseconds.
    fn ms_to_ns(value: &str) -> Option<u64> {
        const NS_PER_MS: u64 = 1_000_000;
        const FRACTION_DIGITS: usize = 6;

        let (whole, fraction) = value.split_once('.').unwrap_or((value, ""));
        let plain_decimal = !whole.is_empty()
            && whole.bytes().all(|b| b.is_ascii_digit())
            && fraction.bytes().all(|b| b.is_ascii_digit());

        if plain_decimal {
            // An all-digit string only fails to parse when it overflows u64;
            // that case falls through to the saturating float path below.
            if let Ok(whole_ms) = whole.parse::<u64>() {
                // Right-pad (or cut) the fraction to exactly six digits so it
                // reads directly as nanoseconds.
                let fraction_ns = fraction
                    .bytes()
                    .chain(std::iter::repeat(b'0'))
                    .take(FRACTION_DIGITS)
                    .fold(0u64, |acc, digit| acc * 10 + u64::from(digit - b'0'));
                return Some(
                    whole_ms
                        .saturating_mul(NS_PER_MS)
                        .saturating_add(fraction_ns),
                );
            }
        }

        // Other spellings (sign, exponent, leading '.', inf/NaN) keep going
        // through f64. Rounding instead of truncating avoids the off-by-one;
        // the `as` cast still saturates and maps NaN/negatives to 0.
        value
            .parse::<f64>()
            .ok()
            .map(|ms| (ms * 1_000_000.0).round() as u64)
    }

    let mut nr_switches: Option<u64> = None;
    let mut sum_exec_runtime_ns: Option<u64> = None;

    for line in content.lines() {
        let Some((key, value)) = line.split_once(':') else {
            continue;
        };
        let value = value.trim();
        match key.trim() {
            "nr_switches" => nr_switches = value.parse::<u64>().ok(),
            "se.sum_exec_runtime" => sum_exec_runtime_ns = ms_to_ns(value),
            _ => {}
        }
        if nr_switches.is_some() && sum_exec_runtime_ns.is_some() {
            break;
        }
    }

    nr_switches.zip(sum_exec_runtime_ns)
}
/// Takes one scheduler sample for `pid` from `/proc/<pid>/sched`.
///
/// Returns `None` when the file cannot be read or `parse_sched_file` does not
/// find both counters. The sample's timestamp is taken after parsing.
fn read_sched_sample(pid: libc::pid_t) -> Option<SchedSample> {
    let path = format!("/proc/{pid}/sched");
    let Ok(text) = std::fs::read_to_string(path) else {
        return None;
    };
    parse_sched_file(&text).map(|(nr_switches, sum_exec_runtime_ns)| SchedSample {
        nr_switches,
        sum_exec_runtime_ns,
        captured_at: Instant::now(),
    })
}
/// Reads `/proc/<pid>/comm`.
///
/// On success the contents are returned with trailing newlines stripped; on
/// failure the I/O error is returned rendered as a string.
fn read_comm(pid: libc::pid_t) -> std::result::Result<String, String> {
    let path = format!("/proc/{pid}/comm");
    match std::fs::read_to_string(path) {
        Ok(raw) => Ok(raw.trim_end_matches('\n').to_string()),
        Err(err) => Err(err.to_string()),
    }
}
/// Gathers a best-effort `/proc` snapshot for `pid`.
///
/// Never fails: unreadable text fields become `[unreadable: <error>]`
/// placeholders, an unreadable stack becomes `None`, and `state` becomes
/// `"?"` when no `State:` token is found in the status text. Files are read
/// in the order wchan, syscall, status, cgroup, stack, `/proc/loadavg`.
fn capture_diagnostic(pid: libc::pid_t) -> StallDiagnostic {
    let wchan = read_proc_field(pid, "wchan");
    let syscall = read_proc_field(pid, "syscall");
    let status_full = read_proc_field(pid, "status");

    // Field initializers run top to bottom, so the reads below keep their
    // order and `status_full` is only moved after `state` was derived from it.
    StallDiagnostic {
        wchan,
        syscall,
        state: extract_state_letter(&status_full),
        cgroup: read_proc_field(pid, "cgroup"),
        stack: std::fs::read_to_string(format!("/proc/{pid}/stack")).ok(),
        host_loadavg: match std::fs::read_to_string("/proc/loadavg") {
            Ok(text) => text.trim_end_matches('\n').to_string(),
            Err(e) => format!("[unreadable: {e}]"),
        },
        status_full,
    }
}
/// Reads `/proc/<pid>/<field>` as text.
///
/// Returns the contents with trailing newlines stripped, or
/// `[unreadable: <error>]` when the file cannot be read.
fn read_proc_field(pid: libc::pid_t, field: &str) -> String {
    let path = format!("/proc/{pid}/{field}");
    std::fs::read_to_string(path)
        .map(|text| text.trim_end_matches('\n').to_string())
        .unwrap_or_else(|e| format!("[unreadable: {e}]"))
}
/// Pulls the task state token out of `/proc/<pid>/status` text.
///
/// Looks at lines that start with `State:` and returns the first
/// whitespace-separated token found after that prefix. A `State:` line with
/// nothing after it is skipped. Returns `"?"` when no token is found.
fn extract_state_letter(status: &str) -> String {
    status
        .lines()
        .filter_map(|line| line.strip_prefix("State:"))
        .find_map(|rest| rest.split_whitespace().next())
        .map_or_else(|| "?".to_string(), str::to_string)
}
// Unit tests for the sched-file parser, the stall predicate, the per-pid
// window/arming logic and the best-effort diagnostic capture.
#[cfg(test)]
mod tests {
use super::*;
/// Parses a hand-written sched listing and checks both extracted values,
/// including the millisecond-to-nanosecond scaling of the runtime.
#[test]
fn parse_sched_file_extracts_signals() {
let content = "\
worker_0 (12345, #threads: 1)
-------------------------------------------------------------------
se.exec_start : 123456789.123456
se.vruntime : 789.012345
se.sum_exec_runtime : 1234.567890
nr_migrations : 7
nr_switches : 42
nr_voluntary_switches : 30
nr_involuntary_switches : 12
clock-delta : 0
";
let parsed = parse_sched_file(content).expect("both fields present");
assert_eq!(parsed.0, 42, "nr_switches");
assert_eq!(parsed.1, 1_234_567_890, "sum_exec_runtime in ns");
}
/// Runs the parser against the real `/proc/self/sched` of the test process.
/// Returns early (passing vacuously) when that file cannot be read.
#[test]
fn parse_sched_file_handles_live_proc_self_sched() {
let Ok(content) = std::fs::read_to_string("/proc/self/sched") else {
return;
};
let parsed = parse_sched_file(&content)
.expect("live /proc/self/sched MUST parse — kernel-format regression");
assert!(
parsed.0 >= 1,
"live nr_switches must be >= 1, got {}",
parsed.0
);
assert!(
parsed.1 > 0,
"live sum_exec_runtime_ns must be > 0, got {}",
parsed.1
);
}
/// Test helper: builds a sample with the given counters and a fresh
/// timestamp.
fn s(nr: u64, ns: u64) -> SchedSample {
SchedSample {
nr_switches: nr,
sum_exec_runtime_ns: ns,
captured_at: Instant::now(),
}
}
/// A window of exactly `STALL_WINDOW` identical samples trips the predicate.
#[test]
fn stall_predicate_fires_after_w_samples_no_delta() {
let samples: Vec<SchedSample> = (0..STALL_WINDOW).map(|_| s(100, 5_000)).collect();
assert!(
stall_predicate(&samples),
"all-flat window of W samples must fire"
);
}
/// The predicate stays false when either counter changes anywhere in the
/// window, and when the window holds fewer than `STALL_WINDOW` samples.
#[test]
fn stall_predicate_skips_when_delta_present() {
// Delta in nr_switches on the last pair.
let samples = vec![s(100, 5_000), s(100, 5_000), s(100, 5_000), s(101, 5_000)];
assert!(
!stall_predicate(&samples),
"any non-zero delta in any consecutive pair must keep predicate false",
);
// Delta in sum_exec_runtime_ns on a middle pair.
let samples = vec![s(100, 5_000), s(100, 5_000), s(100, 5_001), s(100, 5_001)];
assert!(
!stall_predicate(&samples),
"exec_runtime delta in any pair must keep predicate false",
);
// Flat but one sample short of the window size.
let short: Vec<SchedSample> = (0..STALL_WINDOW - 1).map(|_| s(100, 5_000)).collect();
assert!(
!stall_predicate(&short),
"window shorter than STALL_WINDOW must not fire (insufficient signal)",
);
}
/// Captures diagnostics for pid 0, for which the test expects every
/// per-pid `/proc` read to fail, and checks each field's fallback value.
#[test]
fn diagnostic_capture_skips_unreadable_fields() {
let diag = capture_diagnostic(0);
assert!(
diag.wchan.starts_with("[unreadable:"),
"wchan must degrade: got {:?}",
diag.wchan,
);
assert!(
diag.syscall.starts_with("[unreadable:"),
"syscall must degrade: got {:?}",
diag.syscall,
);
assert!(
diag.status_full.starts_with("[unreadable:"),
"status must degrade: got {:?}",
diag.status_full,
);
assert!(
diag.cgroup.starts_with("[unreadable:"),
"cgroup must degrade: got {:?}",
diag.cgroup,
);
assert_eq!(diag.state, "?", "unreadable status → state = ?");
assert!(diag.stack.is_none(), "missing stack must remain None");
// host_loadavg does not depend on the pid, so only non-emptiness is checked.
assert!(
!diag.host_loadavg.is_empty(),
"host_loadavg must always populate (success OR stand-in)",
);
}
/// Replays the push/evict steps by hand (without calling
/// `process_iteration`) and checks which samples survive after pushing
/// `STALL_WINDOW + 2` distinct samples.
#[test]
fn ring_buffer_sliding_window_correctness() {
let mut window: VecDeque<SchedSample> = VecDeque::with_capacity(STALL_WINDOW);
for i in 0..(STALL_WINDOW + 2) {
window.push_back(s(i as u64, i as u64 * 10));
while window.len() > STALL_WINDOW {
window.pop_front();
}
}
assert_eq!(
window.len(),
STALL_WINDOW,
"window size must stay at STALL_WINDOW after overflow",
);
let head = window.front().expect("window non-empty");
assert_eq!(
head.nr_switches, 2,
"oldest sample must be index 2 after 2 evictions"
);
let tail = window.back().expect("window non-empty");
assert_eq!(
tail.nr_switches,
(STALL_WINDOW + 1) as u64,
"newest sample must be the last pushed (index W+1)",
);
let snap: Vec<SchedSample> = window.iter().copied().collect();
assert!(
!stall_predicate(&snap),
"monotonic samples must not trip predicate"
);
}
/// With an armed, empty window, the first `STALL_WINDOW - 1` flat samples
/// do not fire; the sample that fills the window fires and disarms.
#[test]
fn process_iteration_spawn_gate_short_window_never_fires() {
let mut window: VecDeque<SchedSample> = VecDeque::with_capacity(STALL_WINDOW);
let mut armed = true;
for _ in 0..(STALL_WINDOW - 1) {
assert!(
!process_iteration(s(100, 5_000), &mut window, &mut armed),
"short window must not fire (spawn-gate semantic)",
);
}
assert!(armed, "no resume seen → stays armed");
assert!(
process_iteration(s(100, 5_000), &mut window, &mut armed),
"Wth flat sample fills window AND trips predicate → fire",
);
assert!(!armed, "fire path disarms");
}
/// After a first fire, a sample with changed counters re-arms without
/// firing, and a following run of flat samples fires a second time.
#[test]
fn process_iteration_rearm_after_stall_then_resume() {
let mut window: VecDeque<SchedSample> = VecDeque::with_capacity(STALL_WINDOW);
let mut armed = true;
// First stall: fills the window with flat samples and fires once.
for _ in 0..STALL_WINDOW {
process_iteration(s(100, 5_000), &mut window, &mut armed);
}
assert!(!armed, "after first fire, disarmed");
assert!(
!process_iteration(s(101, 5_001), &mut window, &mut armed),
"resume sample must not fire (last pair has delta)",
);
assert!(armed, "resume sample must re-arm");
// Second stall: keep feeding the resumed values until the window is flat again.
let mut second_fire_iter = None;
for i in 0..STALL_WINDOW {
if process_iteration(s(101, 5_001), &mut window, &mut armed) {
second_fire_iter = Some(i);
break;
}
}
assert!(
second_fire_iter.is_some(),
"second stall window must fire after re-arm; got no fire across {} iters",
STALL_WINDOW,
);
assert!(!armed, "second fire disarms");
}
/// A pid whose counters never change again produces exactly one fire, no
/// matter how many further flat samples arrive.
#[test]
fn process_iteration_permanent_stall_fires_only_once() {
let mut window: VecDeque<SchedSample> = VecDeque::with_capacity(STALL_WINDOW);
let mut armed = true;
let mut fire_count = 0;
for _ in 0..STALL_WINDOW {
if process_iteration(s(100, 5_000), &mut window, &mut armed) {
fire_count += 1;
}
}
assert_eq!(fire_count, 1, "first window MUST fire exactly once");
// 100 more flat samples must not add any fire.
for _ in 0..100 {
if process_iteration(s(100, 5_000), &mut window, &mut armed) {
fire_count += 1;
}
}
assert_eq!(
fire_count, 1,
"permanently-stuck pid must NOT spam reports — exactly one fire across many iters",
);
}
}