use std::os::unix::io::AsRawFd;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use serde::{Deserialize, Serialize};
use vmm_sys_util::epoll::{ControlOperation, Epoll, EpollEvent, EventSet};
use vmm_sys_util::eventfd::{EFD_NONBLOCK, EventFd};
use super::PiMutex;
use super::virtio_console::VirtioConsole;
/// Synthetic relay-error envelope the drainer injects into the response
/// buffer when the hard response cap is exceeded; `request_raw` parses it
/// like any other relay error line.
const CAP_OVERFLOW_ERROR_REPLY: &[u8] = b"{\"ktstr_relay_error\":\"response cap overflow\"}\n";
/// Maximum on-wire request size (payload plus trailing newline), in bytes.
pub const MAX_REQUEST_BYTES: usize = 256 * 1024;
/// Soft cap on accumulated response bytes before a blocked request fails
/// with `ResponseTooLarge`; the drainer's hard cap is twice this value.
pub const MAX_RESPONSE_BYTES: usize = 256 * 1024;
/// Errors surfaced by the scx_stats request path.
#[derive(Debug)]
pub enum SchedStatsError {
    /// The response-buffer mutex was poisoned by a panicking thread.
    Poisoned,
    /// The serialized request line (plus newline) exceeded [`MAX_REQUEST_BYTES`].
    RequestTooLarge {
        size: usize,
        max: usize,
    },
    /// The response accumulator passed [`MAX_RESPONSE_BYTES`] without a newline;
    /// the partial bytes were discarded.
    ResponseTooLarge {
        size: usize,
        max: usize,
    },
    /// A freeze rendezvous was active, so the request was rejected.
    DuringFreeze,
    /// The run-wide cancel/kill flag was set before or during the request.
    Cancelled,
    /// The relay reported that no scheduler is available to answer.
    NoScheduler {
        reason: String,
    },
    /// The scheduler answered with a non-zero errno; `args` is its payload.
    SchedulerError {
        errno: i32,
        args: serde_json::Value,
    },
    /// The response envelope's `args` lacked the expected `"resp"` key.
    MissingResp {
        args: serde_json::Value,
    },
}
impl std::fmt::Display for SchedStatsError {
    /// Render the human-readable message for each error variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Poisoned => f.write_str("scx_stats response buffer mutex was poisoned"),
            Self::RequestTooLarge { size, max } => {
                write!(f, "scx_stats request size {size} bytes exceeds cap of {max}")
            }
            Self::ResponseTooLarge { size, max } => write!(
                f,
                "scx_stats response accumulator grew to {size} bytes (cap {max}) without emitting a newline; partial bytes discarded"
            ),
            Self::DuringFreeze => f.write_str(
                "scx_stats request rejected: freeze rendezvous active (scheduler userspace is paused; responses undefined)",
            ),
            Self::Cancelled => f.write_str(
                "scx_stats request cancelled: run-wide kill flag set (watchdog fired or shutdown in progress)",
            ),
            Self::NoScheduler { reason } => {
                write!(f, "scx_stats relay reports no scheduler available: {reason}")
            }
            Self::SchedulerError { errno, args } => {
                write!(f, "scx_stats scheduler returned errno={errno}: {args}")
            }
            Self::MissingResp { args } => {
                write!(f, "scx_stats response envelope missing \"resp\" key in args: {args}")
            }
        }
    }
}
impl std::error::Error for SchedStatsError {}
/// Wire format of one request line sent to the scheduler's stats server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatsRequest {
    /// Request verb, e.g. "stats" or "stats_meta".
    pub req: String,
    /// Optional string arguments; omitted from the wire when empty.
    #[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
    pub args: std::collections::BTreeMap<String, String>,
}
/// Wire format of a decoded response envelope from the stats server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatsResponse {
    /// 0 on success; an errno-style code when the scheduler failed.
    #[serde(default)]
    pub errno: i32,
    /// Response payload; the actual result lives under its `"resp"` key
    /// (see `extract_resp`). Defaults to JSON null when absent.
    #[serde(default = "default_args_value")]
    pub args: serde_json::Value,
}
/// Serde default for [`StatsResponse::args`] when the key is missing.
fn default_args_value() -> serde_json::Value {
    serde_json::Value::Null
}
/// Error envelope emitted by the relay itself (rather than the scheduler),
/// e.g. "no scheduler running"; also injected by the drainer on cap overflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RelayError {
    ktstr_relay_error: String,
}
/// Shared response accumulator: the drainer appends bytes under the mutex and
/// signals the condvar to wake a blocked `request_raw` call.
type StatsResponseBuf = Arc<(Mutex<Vec<u8>>, Condvar)>;
/// State shared between all client clones and the drainer thread.
struct ClientShared {
    // Console carrying the scx_stats traffic on port 2.
    virtio_con: Arc<PiMutex<VirtioConsole>>,
    // Accumulated response bytes plus wakeup condvar.
    response_buf: StatsResponseBuf,
    // Optional run-wide freeze flag; requests are rejected while set.
    freeze: Option<Arc<AtomicBool>>,
    // Optional run-wide cancel flag; requests are rejected/woken when set.
    cancel: Option<Arc<AtomicBool>>,
    // Serializes whole request/response exchanges across clones.
    request_lock: Mutex<()>,
    // True while a `request_raw` call is waiting on its response.
    request_in_flight: Arc<AtomicBool>,
    // Running total of port-2 bytes dropped (stale or unsolicited).
    discarded_bytes: Arc<AtomicU64>,
    // Written on drop to tell the drainer thread to exit.
    kill_drainer: EventFd,
}
impl Drop for ClientShared {
    /// Signal the drainer thread to shut down once the last handle drops.
    fn drop(&mut self) {
        let _ = self.kill_drainer.write(1);
    }
}
/// Cloneable handle for issuing scx_stats requests over virtio console port 2.
/// All clones share one drainer thread and one serialized request pipeline.
#[derive(Clone)]
pub struct SchedStatsClient {
    shared: Arc<ClientShared>,
}
impl SchedStatsClient {
    /// Construct a client over `virtio_con` and spawn the
    /// "ktstr-sched-stats-drain" background thread that pumps port-2 output
    /// into the shared response buffer.
    ///
    /// `freeze` / `cancel` are optional run-wide flags checked before and
    /// while waiting on a request; `cancel_evt` additionally wakes the
    /// drainer (and thereby blocked requests) when the cancel edge fires.
    ///
    /// # Errors
    /// Returns an I/O error if eventfd creation/cloning or thread spawn fails.
    pub(crate) fn new(
        virtio_con: Arc<PiMutex<VirtioConsole>>,
        freeze: Option<Arc<AtomicBool>>,
        cancel: Option<Arc<AtomicBool>>,
        cancel_evt: Option<Arc<EventFd>>,
    ) -> std::io::Result<Self> {
        let response_buf: StatsResponseBuf = Arc::new((Mutex::new(Vec::new()), Condvar::new()));
        let request_in_flight = Arc::new(AtomicBool::new(false));
        let discarded_bytes = Arc::new(AtomicU64::new(0));
        // Written by ClientShared::drop to stop the drainer thread.
        let kill_drainer = EventFd::new(EFD_NONBLOCK)?;
        let kill_drainer_for_thread = kill_drainer.try_clone()?;
        // Eventfd the console signals when port-2 TX bytes are pending.
        let stats_tx_evt = virtio_con.lock().stats_tx_evt().try_clone()?;
        let shared = Arc::new(ClientShared {
            virtio_con: virtio_con.clone(),
            response_buf: Arc::clone(&response_buf),
            freeze,
            cancel,
            request_lock: Mutex::new(()),
            request_in_flight: Arc::clone(&request_in_flight),
            discarded_bytes: Arc::clone(&discarded_bytes),
            kill_drainer,
        });
        let drain_virtio_con = virtio_con;
        let drain_response_buf = response_buf;
        let drain_request_in_flight = request_in_flight;
        let drain_discarded_bytes = discarded_bytes;
        std::thread::Builder::new()
            .name("ktstr-sched-stats-drain".into())
            .spawn(move || {
                drainer_loop(
                    stats_tx_evt,
                    kill_drainer_for_thread,
                    cancel_evt,
                    drain_virtio_con,
                    drain_response_buf,
                    drain_request_in_flight,
                    drain_discarded_bytes,
                );
            })?;
        Ok(Self { shared })
    }
    /// Send one raw request line (a newline is appended on the wire) and
    /// block until a full newline-terminated response line arrives.
    ///
    /// Concurrent calls are serialized by `request_lock`. Stale bytes left
    /// by an abandoned prior request — both in the response buffer and in the
    /// console's pending-RX queue — are discarded and counted first.
    ///
    /// # Errors
    /// - `RequestTooLarge` if `line` plus its newline exceeds [`MAX_REQUEST_BYTES`].
    /// - `DuringFreeze` / `Cancelled` if the respective flag is set before or
    ///   while waiting.
    /// - `ResponseTooLarge` if the accumulator passes [`MAX_RESPONSE_BYTES`]
    ///   without a newline (partial bytes are discarded).
    /// - `NoScheduler` if the relay answers with an error envelope.
    /// - `Poisoned` if the response-buffer mutex was poisoned.
    pub fn request_raw(&self, line: &str) -> Result<Vec<u8>, SchedStatsError> {
        // +1 accounts for the trailing '\n' added on the wire.
        let on_wire_len = line.len().saturating_add(1);
        if on_wire_len > MAX_REQUEST_BYTES {
            return Err(SchedStatsError::RequestTooLarge {
                size: on_wire_len,
                max: MAX_REQUEST_BYTES,
            });
        }
        if let Some(flag) = self.shared.freeze.as_ref()
            && flag.load(Ordering::Acquire)
        {
            return Err(SchedStatsError::DuringFreeze);
        }
        if let Some(flag) = self.shared.cancel.as_ref()
            && flag.load(Ordering::Acquire)
        {
            return Err(SchedStatsError::Cancelled);
        }
        // Serialize whole exchanges; a poisoned lock is still usable because
        // the guard protects no data of its own.
        let _request_guard = self
            .shared
            .request_lock
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        self.shared.request_in_flight.store(true, Ordering::Release);
        // RAII guard clears the in-flight flag on every exit path.
        let _in_flight_guard = InFlightGuard {
            flag: &self.shared.request_in_flight,
        };
        let (lock, cvar) = &*self.shared.response_buf;
        {
            // Drop any response bytes a previous (abandoned) request left.
            let mut buf = lock.lock().map_err(|_| SchedStatsError::Poisoned)?;
            if !buf.is_empty() {
                let stale = buf.len();
                // fetch_add returns the prior total; add `stale` again to log
                // the new running total.
                let total = self
                    .shared
                    .discarded_bytes
                    .fetch_add(stale as u64, Ordering::Relaxed)
                    .saturating_add(stale as u64);
                tracing::debug!(
                    stale_bytes = stale,
                    total_discarded = total,
                    "scx_stats request_raw: clearing stale response bytes from prior call"
                );
                buf.clear();
            }
        }
        {
            let mut g = self.shared.virtio_con.lock();
            // Also drop request bytes a prior caller queued but never pushed.
            let stale_in = g.clear_port2_pending_rx();
            if stale_in > 0 {
                let total = self
                    .shared
                    .discarded_bytes
                    .fetch_add(stale_in as u64, Ordering::Relaxed)
                    .saturating_add(stale_in as u64);
                tracing::debug!(
                    stale_pending_rx = stale_in,
                    total_discarded = total,
                    "scx_stats request_raw: clearing stale port2_pending_rx \
                     (prior request abandoned mid-push)"
                );
            }
            g.queue_input_port2(line.as_bytes());
            g.queue_input_port2(b"\n");
        }
        let mut buf = lock.lock().map_err(|_| SchedStatsError::Poisoned)?;
        loop {
            if let Some(idx) = buf.iter().position(|&b| b == b'\n') {
                // `split_off` keeps the first line (incl. '\n') in `buf` and
                // returns the tail; swap so the tail stays buffered and
                // `response` holds the line, then strip the trailing '\n'.
                let mut response = buf.split_off(idx + 1);
                std::mem::swap(&mut *buf, &mut response);
                response.pop();
                // A relay error envelope means no scheduler answered.
                if let Ok(err) = serde_json::from_slice::<RelayError>(&response) {
                    return Err(SchedStatsError::NoScheduler {
                        reason: err.ktstr_relay_error,
                    });
                }
                return Ok(response);
            }
            if buf.len() > MAX_RESPONSE_BYTES {
                let size = buf.len();
                buf.clear();
                return Err(SchedStatsError::ResponseTooLarge {
                    size,
                    max: MAX_RESPONSE_BYTES,
                });
            }
            if let Some(flag) = self.shared.freeze.as_ref()
                && flag.load(Ordering::Acquire)
            {
                buf.clear();
                return Err(SchedStatsError::DuringFreeze);
            }
            if let Some(flag) = self.shared.cancel.as_ref()
                && flag.load(Ordering::Acquire)
            {
                buf.clear();
                return Err(SchedStatsError::Cancelled);
            }
            // The drainer notifies while holding the same lock, so no wakeup
            // can be lost between the checks above and this wait.
            buf = cvar.wait(buf).map_err(|_| SchedStatsError::Poisoned)?;
        }
    }
    /// Serialize `request` as a JSON line, send it, and decode the response.
    pub fn request(&self, request: &StatsRequest) -> Result<StatsResponse, anyhow::Error> {
        let line = serde_json::to_string(request)?;
        let raw = self.request_raw(&line).map_err(|e| anyhow::anyhow!(e))?;
        Ok(serde_json::from_slice::<StatsResponse>(&raw)?)
    }
    /// Issue a "stats" request with the given key/value arguments and return
    /// the payload found under the response's "resp" key.
    pub fn stats(&self, args: &[(&str, &str)]) -> Result<serde_json::Value, anyhow::Error> {
        let mut req = StatsRequest {
            req: "stats".to_string(),
            args: std::collections::BTreeMap::new(),
        };
        for (k, v) in args {
            req.args.insert((*k).to_string(), (*v).to_string());
        }
        let resp = self.request(&req)?;
        extract_resp(resp)
    }
    /// Issue a "stats_meta" request and return the "resp" payload.
    pub fn stats_meta(&self) -> Result<serde_json::Value, anyhow::Error> {
        let req = StatsRequest {
            req: "stats_meta".to_string(),
            args: std::collections::BTreeMap::new(),
        };
        let resp = self.request(&req)?;
        extract_resp(resp)
    }
    /// Total port-2 bytes discarded so far (stale buffers plus unsolicited
    /// output seen while no request was in flight).
    pub fn discarded_bytes(&self) -> u64 {
        self.shared.discarded_bytes.load(Ordering::Relaxed)
    }
}
/// Unwrap the `"resp"` payload from a decoded [`StatsResponse`].
///
/// Returns the value stored under the `"resp"` key of `resp.args`.
///
/// # Errors
/// - [`SchedStatsError::SchedulerError`] when `resp.errno` is non-zero.
/// - [`SchedStatsError::MissingResp`] when `args` is not a JSON object or
///   lacks a `"resp"` key; the offending `args` value rides in the error.
fn extract_resp(resp: StatsResponse) -> Result<serde_json::Value, anyhow::Error> {
    if resp.errno != 0 {
        return Err(anyhow::anyhow!(SchedStatsError::SchedulerError {
            errno: resp.errno,
            args: resp.args,
        }));
    }
    // BUG FIX: the original used `let serde_json::Value::Object(mut map) =
    // resp.args else { ... args: resp.args ... }` — but `resp.args` is moved
    // into the let-else scrutinee, so reusing it in the diverging branch is a
    // use-after-move (E0382) and does not compile. A `match` binds the
    // non-object value so it can be returned in the error.
    let mut map = match resp.args {
        serde_json::Value::Object(map) => map,
        other => {
            return Err(anyhow::anyhow!(SchedStatsError::MissingResp { args: other }));
        }
    };
    match map.remove("resp") {
        Some(v) => Ok(v),
        None => Err(anyhow::anyhow!(SchedStatsError::MissingResp {
            args: serde_json::Value::Object(map),
        })),
    }
}
impl std::fmt::Debug for SchedStatsClient {
    /// Debug view of the client's observable counters; a poisoned
    /// response-buffer mutex is reported as length 0 rather than panicking.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shared = &self.shared;
        let response_buf_len = match shared.response_buf.0.lock() {
            Ok(guard) => guard.len(),
            Err(_) => 0,
        };
        f.debug_struct("SchedStatsClient")
            .field("response_buf_len", &response_buf_len)
            .field(
                "discarded_bytes",
                &shared.discarded_bytes.load(Ordering::Relaxed),
            )
            .field(
                "request_in_flight",
                &shared.request_in_flight.load(Ordering::Relaxed),
            )
            .finish()
    }
}
/// RAII guard that clears the shared `request_in_flight` flag when a request
/// leaves `request_raw` by any path (normal return, error, or panic).
struct InFlightGuard<'a> {
    // Borrow of ClientShared::request_in_flight; name is relied upon by
    // the struct-literal construction in request_raw.
    flag: &'a Arc<AtomicBool>,
}
impl<'a> Drop for InFlightGuard<'a> {
    fn drop(&mut self) {
        // Release pairs with the drainer's Acquire load of the flag.
        self.flag.store(false, Ordering::Release);
    }
}
/// Background thread body: pump console port-2 output into the shared
/// response buffer and wake blocked `request_raw` callers.
///
/// Epoll-waits on up to three eventfds:
/// - `stats_tx_evt`: console has port-2 bytes pending (drain, then append or
///   discard depending on whether a request is in flight),
/// - `kill_drainer`: the client was dropped (wake waiters and exit),
/// - `cancel_evt` (optional): run-wide cancel edge (wake waiters and exit).
fn drainer_loop(
    stats_tx_evt: EventFd,
    kill_drainer: EventFd,
    cancel_evt: Option<Arc<EventFd>>,
    virtio_con: Arc<PiMutex<VirtioConsole>>,
    response_buf: StatsResponseBuf,
    request_in_flight: Arc<AtomicBool>,
    discarded_bytes: Arc<AtomicU64>,
) {
    // Epoll user-data tokens identifying which fd fired.
    const TOKEN_DATA: u64 = 0;
    const TOKEN_KILL: u64 = 1;
    const TOKEN_CANCEL: u64 = 2;
    let epoll = match Epoll::new() {
        Ok(e) => e,
        Err(e) => {
            tracing::error!(error = %e, "stats drainer: epoll_create1 failed; aborting drainer");
            return;
        }
    };
    for (fd, token, name) in [
        (stats_tx_evt.as_raw_fd(), TOKEN_DATA, "stats_tx_evt"),
        (kill_drainer.as_raw_fd(), TOKEN_KILL, "kill_drainer"),
    ] {
        if let Err(e) = epoll.ctl(
            ControlOperation::Add,
            fd,
            EpollEvent::new(EventSet::IN, token),
        ) {
            tracing::error!(
                error = %e,
                fd_name = name,
                "stats drainer: epoll_ctl ADD failed; aborting drainer"
            );
            return;
        }
    }
    // Failure to register cancel_evt is non-fatal: waiting requests still
    // observe the cancel flag, just without a prompt wakeup.
    if let Some(cancel) = cancel_evt.as_ref()
        && let Err(e) = epoll.ctl(
            ControlOperation::Add,
            cancel.as_raw_fd(),
            EpollEvent::new(EventSet::IN, TOKEN_CANCEL),
        )
    {
        tracing::warn!(
            error = %e,
            "stats drainer: epoll_ctl ADD on cancel_evt failed; \
             cancel edge will not wake blocked requests promptly"
        );
    }
    // At most three fds are registered, so three event slots suffice.
    let mut events_buf = [EpollEvent::default(); 3];
    loop {
        // Block indefinitely; retry on EINTR.
        let event_count = match epoll.wait(-1, &mut events_buf) {
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => {
                tracing::error!(error = %e, "stats drainer: epoll_wait failed; exiting");
                return;
            }
        };
        for ev in &events_buf[..event_count] {
            match ev.data() {
                TOKEN_KILL => {
                    let _ = kill_drainer.read();
                    // Notify while holding the lock so a waiter between its
                    // flag checks and `wait` cannot miss the wakeup.
                    let (lock, cvar) = &*response_buf;
                    let _guard = lock.lock();
                    cvar.notify_all();
                    return;
                }
                TOKEN_CANCEL => {
                    if let Some(c) = cancel_evt.as_ref() {
                        let _ = c.read();
                    }
                    let (lock, cvar) = &*response_buf;
                    let _guard = lock.lock();
                    cvar.notify_all();
                    return;
                }
                TOKEN_DATA => {
                    // Clear the eventfd edge, then pull everything pending.
                    let _ = stats_tx_evt.read();
                    let bytes = {
                        let mut g = virtio_con.lock();
                        g.drain_port2_bulk()
                    };
                    if bytes.is_empty() {
                        continue;
                    }
                    if request_in_flight.load(Ordering::Acquire) {
                        let (lock, cvar) = &*response_buf;
                        if let Ok(mut guard) = lock.lock() {
                            let new_total = guard.len().saturating_add(bytes.len());
                            // Hard cap (2x the soft cap): replace everything
                            // with a synthetic relay-error line so the waiter
                            // fails fast instead of growing without bound.
                            if new_total > MAX_RESPONSE_BYTES.saturating_mul(2) {
                                tracing::warn!(
                                    current = guard.len(),
                                    incoming = bytes.len(),
                                    cap = MAX_RESPONSE_BYTES * 2,
                                    "stats drainer: hard cap reached; injecting cap-overflow error envelope"
                                );
                                guard.clear();
                                guard.extend_from_slice(CAP_OVERFLOW_ERROR_REPLY);
                                cvar.notify_all();
                                continue;
                            }
                            guard.extend_from_slice(&bytes);
                            cvar.notify_all();
                        }
                    } else {
                        // Nobody is waiting: count and drop the bytes.
                        let total = discarded_bytes
                            .fetch_add(bytes.len() as u64, Ordering::Relaxed)
                            .saturating_add(bytes.len() as u64);
                        tracing::debug!(
                            this_drop = bytes.len(),
                            total_discarded = total,
                            "stats drainer: discarding port-2 bytes (no request in flight)"
                        );
                    }
                }
                _ => {}
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    /// Build a client with explicit freeze/cancel wiring over a fresh console.
    fn make_client_full(
        freeze: Option<Arc<AtomicBool>>,
        cancel: Option<Arc<AtomicBool>>,
        cancel_evt: Option<Arc<EventFd>>,
    ) -> SchedStatsClient {
        let virtio_con = Arc::new(PiMutex::new(VirtioConsole::new()));
        SchedStatsClient::new(virtio_con, freeze, cancel, cancel_evt).expect("client construct")
    }
    /// Plain client: no freeze flag, no cancel flag/eventfd.
    fn make_client() -> SchedStatsClient {
        make_client_full(None, None, None)
    }
    /// Spawn a writer that polls (up to ~200ms) until a request is in flight,
    /// then appends `bytes` to the response buffer and wakes the waiter —
    /// simulating the drainer delivering scheduler output.
    fn pre_populate(client: &SchedStatsClient, bytes: &[u8]) -> std::thread::JoinHandle<()> {
        let buf = Arc::clone(&client.shared.response_buf);
        let in_flight = Arc::clone(&client.shared.request_in_flight);
        let bytes = bytes.to_vec();
        std::thread::spawn(move || {
            for _ in 0..200 {
                if in_flight.load(Ordering::Acquire) {
                    break;
                }
                std::thread::sleep(Duration::from_millis(1));
            }
            let (lock, cvar) = &*buf;
            let mut guard = lock.lock().unwrap();
            guard.extend_from_slice(&bytes);
            cvar.notify_all();
        })
    }
    // A newline-terminated line appended while a request is blocked wakes it,
    // and is returned without the trailing newline.
    #[test]
    fn drainer_append_then_request_returns_first_line() {
        let client = make_client();
        let writer = pre_populate(&client, b"hello\n");
        let resp = client
            .request_raw("x")
            .expect("must wake when bytes arrive");
        assert_eq!(resp, b"hello");
        writer.join().unwrap();
    }
    // A request whose on-wire size (len + newline) exceeds the cap is
    // rejected up front; note the +1 for the newline in the reported size.
    #[test]
    fn oversize_request_rejected() {
        let client = make_client();
        let big = "x".repeat(MAX_REQUEST_BYTES);
        let err = client
            .request_raw(&big)
            .expect_err("must reject oversize request");
        match err {
            SchedStatsError::RequestTooLarge { size, max } => {
                assert_eq!(size, MAX_REQUEST_BYTES + 1);
                assert_eq!(max, MAX_REQUEST_BYTES);
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }
    // An unterminated response growing past the soft cap fails the request
    // with ResponseTooLarge.
    #[test]
    fn oversize_response_returns_response_too_large() {
        let client = make_client();
        let payload: Vec<u8> = std::iter::repeat_n(b'A', MAX_RESPONSE_BYTES + 1).collect();
        let writer = pre_populate(&client, &payload);
        let err = client
            .request_raw("x")
            .expect_err("must reject oversize response");
        match err {
            SchedStatsError::ResponseTooLarge { size, max } => {
                assert!(size > MAX_RESPONSE_BYTES);
                assert_eq!(max, MAX_RESPONSE_BYTES);
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
        writer.join().unwrap();
    }
    // Requests are rejected while the freeze flag is set, and succeed again
    // once it clears.
    #[test]
    fn during_freeze_rejects_request() {
        let freeze = Arc::new(AtomicBool::new(true));
        let client = make_client_full(Some(freeze.clone()), None, None);
        let err = client
            .request_raw("x")
            .expect_err("must reject during freeze");
        assert!(matches!(err, SchedStatsError::DuringFreeze));
        freeze.store(false, Ordering::Release);
        let writer = pre_populate(&client, b"ok\n");
        let resp = client.request_raw("x").expect("must succeed after thaw");
        assert_eq!(resp, b"ok");
        writer.join().unwrap();
    }
    // A pre-set cancel flag rejects the request before any waiting happens.
    #[test]
    fn cancel_flag_set_before_request_rejects() {
        let cancel = Arc::new(AtomicBool::new(true));
        let cancel_evt = Arc::new(EventFd::new(EFD_NONBLOCK).unwrap());
        let client = make_client_full(None, Some(cancel), Some(cancel_evt));
        let err = client
            .request_raw("x")
            .expect_err("must reject when cancel pre-set");
        assert!(matches!(err, SchedStatsError::Cancelled));
    }
    // Setting the cancel flag and firing cancel_evt while a request is
    // blocked wakes it with Cancelled (the drainer relays the edge).
    #[test]
    fn cancel_during_blocked_request_wakes() {
        let cancel = Arc::new(AtomicBool::new(false));
        let cancel_evt = Arc::new(EventFd::new(EFD_NONBLOCK).unwrap());
        let client = make_client_full(None, Some(cancel.clone()), Some(cancel_evt.clone()));
        let in_flight = Arc::clone(&client.shared.request_in_flight);
        let waker = std::thread::spawn(move || {
            for _ in 0..200 {
                if in_flight.load(Ordering::Acquire) {
                    break;
                }
                std::thread::sleep(Duration::from_millis(1));
            }
            cancel.store(true, Ordering::Release);
            let _ = cancel_evt.write(1);
        });
        let err = client
            .request_raw("x")
            .expect_err("must wake on cancel edge");
        assert!(matches!(err, SchedStatsError::Cancelled));
        waker.join().unwrap();
    }
    // Two concurrent requests serialize on request_lock; the writer thread
    // feeds each one its own response line in turn.
    #[test]
    fn concurrent_requests_serialise_via_lock() {
        let client = make_client();
        let buf = Arc::clone(&client.shared.response_buf);
        let in_flight = Arc::clone(&client.shared.request_in_flight);
        let writer = std::thread::spawn(move || {
            // Answer whichever request gets the lock first.
            for _ in 0..200 {
                if in_flight.load(Ordering::Acquire) {
                    let (lock, cvar) = &*buf;
                    let mut guard = lock.lock().unwrap();
                    guard.extend_from_slice(b"first\n");
                    cvar.notify_all();
                    break;
                }
                std::thread::sleep(Duration::from_millis(1));
            }
            // Wait for the first request to finish...
            for _ in 0..200 {
                if !in_flight.load(Ordering::Acquire) {
                    break;
                }
                std::thread::sleep(Duration::from_millis(1));
            }
            // ...then answer the second (buffer empty => it consumed "first").
            for _ in 0..200 {
                if in_flight.load(Ordering::Acquire) {
                    let (lock, cvar) = &*buf;
                    let mut guard = lock.lock().unwrap();
                    if guard.is_empty() {
                        guard.extend_from_slice(b"second\n");
                        cvar.notify_all();
                        break;
                    }
                }
                std::thread::sleep(Duration::from_millis(1));
            }
        });
        let c2 = client.clone();
        let h = std::thread::spawn(move || c2.request_raw("a").expect("first request succeeds"));
        std::thread::sleep(Duration::from_millis(20));
        let resp_b = client
            .request_raw("b")
            .expect("second request succeeds (after lock release)");
        let resp_a = h.join().expect("first thread joins");
        // Order of completion is nondeterministic; sort before comparing.
        let mut got = vec![resp_a, resp_b];
        got.sort();
        assert_eq!(got, vec![b"first".to_vec(), b"second".to_vec()]);
        writer.join().unwrap();
    }
    // A relay error envelope is surfaced as NoScheduler with its reason text.
    #[test]
    fn relay_error_envelope_surfaces_as_no_scheduler() {
        let client = make_client();
        // Raw byte string deliberately contains the trailing newline.
        let writer = pre_populate(
            &client,
            br#"{"ktstr_relay_error":"no scheduler running"}
"#,
        );
        let err = client
            .request_raw("x")
            .expect_err("must surface NoScheduler");
        match err {
            SchedStatsError::NoScheduler { reason } => {
                assert_eq!(reason, "no scheduler running");
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
        writer.join().unwrap();
    }
    // StatsRequest/StatsResponse serde round trips, including the
    // skip-empty-args behavior on serialization.
    #[test]
    fn typed_request_round_trip() {
        let req = StatsRequest {
            req: "stats".to_string(),
            args: std::collections::BTreeMap::new(),
        };
        let bytes = serde_json::to_vec(&req).unwrap();
        let decoded: StatsRequest = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(decoded.req, "stats");
        assert!(decoded.args.is_empty());
        let mut req_args = StatsRequest {
            req: "stats".to_string(),
            args: std::collections::BTreeMap::new(),
        };
        req_args
            .args
            .insert("target".to_string(), "top".to_string());
        let bytes_args = serde_json::to_vec(&req_args).unwrap();
        assert_eq!(
            std::str::from_utf8(&bytes_args).unwrap(),
            r#"{"req":"stats","args":{"target":"top"}}"#,
        );
        let decoded_args: StatsRequest = serde_json::from_slice(&bytes_args).unwrap();
        assert_eq!(decoded_args.args.get("target"), Some(&"top".to_string()));
        let resp_wire = br#"{"errno":0,"args":{"resp":{"foo":42}}}"#;
        let resp: StatsResponse = serde_json::from_slice(resp_wire).unwrap();
        assert_eq!(resp.errno, 0);
        assert_eq!(resp.args["resp"]["foo"], 42);
    }
    // extract_resp returns the "resp" payload on errno 0 and maps a non-zero
    // errno to SchedulerError carrying the args.
    #[test]
    fn extract_resp_happy_path_and_errno() {
        let resp_ok = StatsResponse {
            errno: 0,
            args: serde_json::json!({"resp": {"counter": 7}}),
        };
        let payload = extract_resp(resp_ok).unwrap();
        assert_eq!(payload["counter"], 7);
        let resp_err = StatsResponse {
            errno: 22,
            args: serde_json::json!({"message": "EINVAL"}),
        };
        let err = extract_resp(resp_err).unwrap_err();
        let downcast = err.downcast_ref::<SchedStatsError>().expect("downcast");
        match downcast {
            SchedStatsError::SchedulerError { errno, args } => {
                assert_eq!(*errno, 22);
                assert_eq!(args["message"], "EINVAL");
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }
    // An args object without a "resp" key yields MissingResp.
    #[test]
    fn extract_resp_missing_resp_key() {
        let resp = StatsResponse {
            errno: 0,
            args: serde_json::json!({"other": 1}),
        };
        let err = extract_resp(resp).unwrap_err();
        let downcast = err.downcast_ref::<SchedStatsError>().expect("downcast");
        assert!(matches!(downcast, SchedStatsError::MissingResp { .. }));
    }
}