use crate::sync::MutexExt;
use crate::vmm::wire::{
KERNEL_OP_REPLY_MAX, KernelOpReplyPayload, KernelOpRequestPayload, KernelOpRequestResult,
LifecyclePhase, MSG_TYPE_KERNEL_OP_REPLY, MSG_TYPE_SNAPSHOT_REPLY, MsgType,
SNAPSHOT_REASON_MAX, SNAPSHOT_STATUS_ERR, SNAPSHOT_STATUS_OK, SNAPSHOT_TAG_MAX, ShmMessage,
SnapshotReplyPayload, SnapshotRequestPayload, SnapshotRequestResult,
};
use zerocopy::{FromBytes, IntoBytes};
/// Serializes guest-side message writes: `write_msg` holds this lock across
/// each framed write to the bulk port.
pub static GUEST_WRITE_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Reports whether this process runs inside a ktstr guest VM.
///
/// Detection looks for the exact token `KTSTR_GUEST=1` on the kernel command
/// line (`/proc/cmdline`). An unreadable command line counts as "not a guest".
/// The answer is computed once and cached for the life of the process.
///
/// In test builds, a per-thread override installed with
/// `IsGuestOverrideGuard` takes precedence over real detection.
pub fn is_guest() -> bool {
    #[cfg(test)]
    {
        let forced = IS_GUEST_TEST_OVERRIDE.with(|slot| slot.get());
        if let Some(answer) = forced {
            return answer;
        }
    }
    static IS_GUEST: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    let cached = IS_GUEST.get_or_init(|| match std::fs::read_to_string("/proc/cmdline") {
        Ok(cmdline) => cmdline
            .split_whitespace()
            .any(|tok| tok == "KTSTR_GUEST=1"),
        Err(_) => false,
    });
    *cached
}
#[cfg(test)]
thread_local! {
    /// Per-thread override consulted first by `is_guest()` in test builds;
    /// `None` means "no override, fall through to real detection".
    static IS_GUEST_TEST_OVERRIDE: std::cell::Cell<Option<bool>> = const { std::cell::Cell::new(None) };
}

/// Test-only guard that forces `is_guest()` to a fixed answer on the current
/// thread while it is alive.
///
/// Dropping the guard reinstates whatever override was active when it was
/// created, so guards nest correctly.
#[cfg(test)]
pub(crate) struct IsGuestOverrideGuard {
    restore_to: Option<bool>,
}

#[cfg(test)]
impl IsGuestOverrideGuard {
    /// Install `value` as this thread's override, remembering the old one.
    pub(crate) fn new(value: bool) -> Self {
        IS_GUEST_TEST_OVERRIDE.with(|slot| {
            let restore_to = slot.get();
            slot.set(Some(value));
            Self { restore_to }
        })
    }
}

#[cfg(test)]
impl Drop for IsGuestOverrideGuard {
    fn drop(&mut self) {
        let restore_to = self.restore_to;
        IS_GUEST_TEST_OVERRIDE.with(move |slot| slot.set(restore_to));
    }
}
/// Gate for guest-only transport calls.
///
/// Returns `true` when running inside the guest. Otherwise it logs a warning
/// naming `fn_name` and `msg_type`, then returns `false` so the caller can
/// turn the call into a no-op.
fn assert_guest_context(fn_name: &str, msg_type: u32) -> bool {
    let in_guest = is_guest();
    if !in_guest {
        tracing::warn!(msg_type, "guest_comms::{fn_name} called from host context");
    }
    in_guest
}
/// Guest-side device node of the bulk virtio-console port (port 1).
pub(crate) const BULK_PORT_DEV: &str = "/dev/vport0p1";
/// Lazily opened handle to `BULK_PORT_DEV`, shared by writers
/// (`write_to_bulk_port`) and reply readers (`request_snapshot`,
/// `request_kernel_op`).
///
/// The slot holds `None` until the first successful open. It is reset to
/// `None` after a write failure or a non-timeout read failure, so the next
/// user reopens the device.
static BULK_PORT_FD: std::sync::OnceLock<std::sync::Mutex<Option<std::fs::File>>> =
    std::sync::OnceLock::new();
/// Test-only count of calls that reached `write_to_bulk_port`. Tests use it to
/// assert that the host-context gate in `write_msg` short-circuits before any
/// I/O.
#[cfg(test)]
pub(crate) static BULK_PORT_WRITE_ATTEMPTS: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0);
/// Try to open the bulk port device for reading and writing.
///
/// Any open error (for example, the device node not existing yet) yields
/// `None`; the error itself is discarded.
fn try_open_bulk_port() -> Option<std::fs::File> {
    let mut opts = std::fs::OpenOptions::new();
    opts.read(true).write(true);
    match opts.open(BULK_PORT_DEV) {
        Ok(file) => Some(file),
        Err(_) => None,
    }
}
/// Send one TLV frame of type `msg_type` to the host.
///
/// From host context this is a logged no-op that returns `false`. In the
/// guest, the write happens under `GUEST_WRITE_LOCK` and the return value
/// reports whether the whole frame was written.
fn write_msg(msg_type: u32, payload: &[u8]) -> bool {
    if assert_guest_context("write_msg", msg_type) {
        let _serialize = GUEST_WRITE_LOCK.lock_unpoisoned();
        write_to_bulk_port(msg_type, payload)
    } else {
        false
    }
}
/// Write one framed TLV message (`ShmMessage` header followed by `payload`) to
/// the bulk port.
///
/// On first use, opens `BULK_PORT_DEV` into the shared `BULK_PORT_FD` slot and
/// caches the handle for later calls. The slot's mutex is held for the whole
/// write. The header carries `msg_type`, the payload length and a CRC32 of the
/// payload. Header and payload go out through gathering `writev` calls, which
/// loop on partial writes and retry on `EINTR`.
///
/// Returns `true` once every byte has been written. Returns `false` in these
/// cases:
/// - the port cannot be opened;
/// - the payload length does not fit the header's `u32` field;
/// - `writev` fails or returns 0.
///
/// In the two `writev` failure cases the cached handle is dropped, so the next
/// call reopens the device.
fn write_to_bulk_port(msg_type: u32, payload: &[u8]) -> bool {
    #[cfg(test)]
    BULK_PORT_WRITE_ATTEMPTS.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    let slot = BULK_PORT_FD.get_or_init(|| std::sync::Mutex::new(None));
    let mut guard = slot.lock_unpoisoned();
    if guard.is_none() {
        match try_open_bulk_port() {
            Some(f) => *guard = Some(f),
            None => return false,
        }
    }
    // Invariant: the branch above either installed a handle or returned.
    let f = guard.as_mut().expect("bulk port handle just installed");
    // The wire header stores the length as u32; refuse rather than truncate.
    let Ok(length_u32) = u32::try_from(payload.len()) else {
        tracing::warn!(
            len = payload.len(),
            msg_type,
            "write_to_bulk_port: payload exceeds u32::MAX; dropping"
        );
        return false;
    };
    let msg = ShmMessage {
        msg_type,
        length: length_u32,
        crc32: crc32fast::hash(payload),
        _pad: 0,
    };
    let header_bytes = msg.as_bytes();
    let total = header_bytes.len() + payload.len();
    let fd = std::os::unix::io::AsRawFd::as_raw_fd(f);
    let mut iovs = [
        std::io::IoSlice::new(header_bytes),
        std::io::IoSlice::new(payload),
    ];
    // `bufs` shrinks as bytes are written; `advance_slices` also removes
    // zero-length trailing slices (e.g. an empty payload), so the loop ends
    // exactly when everything has been sent.
    let mut bufs: &mut [std::io::IoSlice<'_>] = &mut iovs[..];
    let mut written: usize = 0;
    while !bufs.is_empty() {
        // SAFETY: std guarantees `IoSlice` is ABI-compatible with `iovec` on
        // Unix, so the pointer cast is valid. `bufs` is a live slice of at most
        // two entries (so the `c_int` cast cannot overflow). `fd` belongs to
        // the `File` borrowed from `guard`, which outlives this call.
        let r = unsafe {
            libc::writev(
                fd,
                bufs.as_ptr() as *const libc::iovec,
                bufs.len() as libc::c_int,
            )
        };
        if r < 0 {
            let err = std::io::Error::last_os_error();
            if err.kind() == std::io::ErrorKind::Interrupted {
                continue;
            }
            tracing::warn!(
                %err,
                msg_type,
                len = payload.len(),
                "write_to_bulk_port: writev failed"
            );
            // Drop the handle so the next write reopens the device.
            *guard = None;
            return false;
        }
        if r == 0 {
            // No progress with bytes still pending; bail instead of spinning.
            tracing::warn!(
                msg_type,
                len = payload.len(),
                written,
                total,
                "write_to_bulk_port: writev returned 0"
            );
            *guard = None;
            return false;
        }
        let n = r as usize;
        written += n;
        std::io::IoSlice::advance_slices(&mut bufs, n);
    }
    debug_assert_eq!(written, total);
    true
}
/// Send `code` (little-endian `i32`) to the host as an `Exit` frame.
/// Best-effort: a failed write is ignored.
pub fn send_exit(code: i32) {
    let encoded = code.to_le_bytes();
    let _ = write_msg(MsgType::Exit.wire_value(), &encoded);
}
/// How an `AssertResult` will be shipped over the bulk port, as decided by
/// `classify_test_result`.
#[derive(Debug)]
enum TestResultWire {
    /// The full postcard encoding fits within the frame limit.
    Raw(Vec<u8>),
    /// The encoding with per-cgroup raw samples stripped fits. `dropped` is the
    /// count returned by `strip_phase_cgroup_samples`.
    Stripped { bytes: Vec<u8>, dropped: usize },
    /// Nothing fits. `offending` is the post-strip encoded size when stripping
    /// removed samples and re-encoding succeeded; otherwise it is the original
    /// encoded size.
    Truncated { offending: usize },
}
/// Decide how `result` can be encoded to fit in a `max`-byte bulk frame.
///
/// - If the full postcard encoding fits, returns `Raw`.
/// - Otherwise it strips per-cgroup raw samples and appends an explanatory
///   note. If that re-encodes within `max`, returns `Stripped`.
/// - Otherwise returns `Truncated` with the size that was still too large.
///
/// Returns `None` only if the initial postcard encoding fails.
fn classify_test_result(
    result: &crate::assert::AssertResult,
    max: usize,
) -> Option<TestResultWire> {
    let full = postcard::to_stdvec(result).ok()?;
    let full_len = full.len();
    if full_len <= max {
        return Some(TestResultWire::Raw(full));
    }

    let mut reduced = result.clone();
    let dropped = reduced.strip_phase_cgroup_samples();
    if dropped == 0 {
        // Nothing to strip: the original size is the offending one.
        return Some(TestResultWire::Truncated { offending: full_len });
    }
    reduced.note(format!(
        "per_cgroup raw samples ({dropped}) dropped: AssertResult postcard size {} \
         exceeded bulk port limit {max}; verdict and reduced telemetry preserved",
        full_len,
    ));
    let outcome = match postcard::to_stdvec(&reduced) {
        Ok(small) if small.len() <= max => TestResultWire::Stripped {
            bytes: small,
            dropped,
        },
        Ok(small) => TestResultWire::Truncated {
            offending: small.len(),
        },
        Err(_) => TestResultWire::Truncated {
            offending: full_len,
        },
    };
    Some(outcome)
}
/// Ship a test verdict to the host as a `TestResult` frame.
///
/// Oversized results are first reduced by dropping per-cgroup raw samples.
/// If even that does not fit, a synthetic failing `AssertResult` explaining
/// the overflow is sent instead. Encoding failures are reported on stderr and
/// send nothing.
pub fn send_test_result(result: &crate::assert::AssertResult) {
    let max = crate::vmm::bulk::MAX_BULK_FRAME_PAYLOAD as usize;
    let Some(classified) = classify_test_result(result, max) else {
        eprintln!("ktstr: postcard-encode AssertResult for bulk-port emit failed");
        return;
    };
    let wire_type = MsgType::TestResult.wire_value();
    match classified {
        TestResultWire::Raw(bytes) => {
            let _ = write_msg(wire_type, &bytes);
        }
        TestResultWire::Stripped { bytes, dropped } => {
            tracing::warn!(
                stripped = bytes.len(),
                dropped_samples = dropped,
                max,
                "AssertResult exceeded bulk frame; dropped per_cgroup raw samples, verdict preserved"
            );
            let _ = write_msg(wire_type, &bytes);
        }
        TestResultWire::Truncated { offending } => {
            tracing::error!(
                offending_size = offending,
                max,
                "AssertResult exceeds bulk port frame limit even after dropping samples, sending truncated verdict"
            );
            let detail = crate::assert::AssertDetail::new(
                crate::assert::DetailKind::Other,
                format!(
                    "AssertResult postcard size {offending} exceeded bulk port limit {max}; \
                     original details dropped",
                ),
            );
            let truncated = crate::assert::AssertResult::fail(detail);
            if let Ok(small) = postcard::to_stdvec(&truncated) {
                let _ = write_msg(wire_type, &small);
            }
        }
    }
}
/// Postcard-encode `metrics` and send it as a `PayloadMetrics` frame.
///
/// Returns `true` only if the frame was written. An encoding error is printed
/// to stderr and yields `false`.
pub fn send_payload_metrics(metrics: &crate::test_support::PayloadMetrics) -> bool {
    let bytes = match postcard::to_stdvec(metrics) {
        Ok(bytes) => bytes,
        Err(e) => {
            eprintln!("ktstr: postcard-encode PayloadMetrics for bulk-port emit: {e}");
            return false;
        }
    };
    write_msg(MsgType::PayloadMetrics.wire_value(), &bytes)
}
/// Send raw profile bytes as a `Profraw` frame (test/coverage builds only).
/// Best-effort.
#[cfg(any(test, coverage))]
pub fn send_profraw(buf: &[u8]) {
    let _ = write_msg(MsgType::Profraw.wire_value(), buf);
}

/// Send trace bytes as a `WprofTrace` frame (`wprof` feature only).
/// Best-effort.
#[cfg(feature = "wprof")]
pub fn send_wprof_trace(buf: &[u8]) {
    let _ = write_msg(MsgType::WprofTrace.wire_value(), buf);
}

/// Send `payload` verbatim as a `Stimulus` frame. Best-effort.
pub fn send_stimulus(payload: &[u8]) {
    let _ = write_msg(MsgType::Stimulus.wire_value(), payload);
}

/// Send `payload` verbatim as a `StepEnd` frame. Best-effort.
pub fn send_step_end(payload: &[u8]) {
    let _ = write_msg(MsgType::StepEnd.wire_value(), payload);
}
/// Postcard-encode `raw` and send it as a `RawPayloadOutput` frame.
///
/// Best-effort: write failures are ignored, and an encoding error is printed
/// to stderr.
pub(crate) fn send_raw_payload_output(raw: &crate::test_support::RawPayloadOutput) {
    match postcard::to_stdvec(raw) {
        Err(e) => {
            eprintln!("ktstr: postcard-encode RawPayloadOutput for bulk-port emit: {e}");
        }
        Ok(bytes) => {
            let _ = write_msg(MsgType::RawPayloadOutput.wire_value(), &bytes);
        }
    }
}
/// Send `code` (little-endian `i32`) as a `SchedExit` frame. Best-effort.
pub fn send_sched_exit(code: i32) {
    let encoded = code.to_le_bytes();
    let _ = write_msg(MsgType::SchedExit.wire_value(), &encoded);
}
pub fn send_scenario_start() {
for attempt in 0..5 {
if write_msg(MsgType::ScenarioStart.wire_value(), &[]) {
return;
}
if attempt + 1 < 5 {
std::thread::sleep(std::time::Duration::from_millis(100));
}
}
tracing::warn!(
"send_scenario_start: 5 retries failed — bulk port write never \
succeeded; periodic captures will see scenario_anchor=0 and \
silently 0-fire"
);
}
/// Send a `ScenarioEnd` frame.
///
/// The payload is `elapsed_ms` at byte 0 and `total_iterations` at byte 8,
/// both little-endian `u64`. The rest of the fixed-size payload stays zero.
/// Best-effort.
pub fn send_scenario_end(elapsed_ms: u64, total_iterations: u64) {
    let mut frame = [0u8; crate::vmm::wire::SCENARIO_END_PAYLOAD_SIZE];
    for (offset, value) in [(0usize, elapsed_ms), (8usize, total_iterations)] {
        frame[offset..offset + 8].copy_from_slice(&value.to_le_bytes());
    }
    let _ = write_msg(MsgType::ScenarioEnd.wire_value(), &frame);
}
/// Send an empty `ScenarioPause` frame. Best-effort.
pub fn send_scenario_pause() {
    let _ = write_msg(MsgType::ScenarioPause.wire_value(), &[]);
}

/// Send an empty `ScenarioResume` frame. Best-effort.
pub fn send_scenario_resume() {
    let _ = write_msg(MsgType::ScenarioResume.wire_value(), &[]);
}

/// Send an empty `SysRdy` frame; returns whether it was written, so callers
/// can retry.
pub fn send_sys_rdy() -> bool {
    let empty: [u8; 0] = [];
    write_msg(MsgType::SysRdy.wire_value(), &empty)
}
/// Send `addrs` (serialized via `KernAddrs::to_payload`) as a `KERN_ADDRS`
/// frame; returns whether it was written.
pub fn send_kern_addrs(addrs: &super::wire::KernAddrs) -> bool {
    write_msg(super::wire::MSG_TYPE_KERN_ADDRS, &addrs.to_payload())
}
/// Kernel virtual address of `_text` (type `T`/`t`) from `/proc/kallsyms`.
pub fn read_kernel_text_from_kallsyms() -> Option<u64> {
    const TEXT_SYMBOL_TYPES: &[&str] = &["T", "t"];
    read_kallsyms_symbol_kva("_text", TEXT_SYMBOL_TYPES)
}

/// Kernel virtual address of `page_offset_base` (type `D`/`d`) from
/// `/proc/kallsyms`.
pub fn read_kernel_page_offset_base_from_kallsyms() -> Option<u64> {
    const DATA_SYMBOL_TYPES: &[&str] = &["D", "d"];
    read_kallsyms_symbol_kva("page_offset_base", DATA_SYMBOL_TYPES)
}
/// Look up the kernel virtual address of symbol `name` in `/proc/kallsyms`.
///
/// Only entries whose type letter is in `allowed_types` match. Entries listed
/// at address 0 are ignored (presumably addresses hidden by `kptr_restrict`).
/// Returns `None` if the file cannot be read or no matching non-zero entry
/// exists.
fn read_kallsyms_symbol_kva(name: &str, allowed_types: &[&str]) -> Option<u64> {
    let kallsyms = std::fs::read_to_string("/proc/kallsyms").ok()?;
    find_kallsyms_symbol_kva(&kallsyms, name, allowed_types)
}

/// Scan kallsyms-formatted text (`<hex addr> <type> <name> [module]` per
/// line). Returns the first non-zero address of `name` whose type is in
/// `allowed_types`.
///
/// Lines that are too short or whose address is not valid hex are skipped.
/// Previously, a `?` inside the loop aborted the whole scan on the first such
/// line.
fn find_kallsyms_symbol_kva(kallsyms: &str, name: &str, allowed_types: &[&str]) -> Option<u64> {
    kallsyms.lines().find_map(|line| {
        let mut fields = line.split_ascii_whitespace();
        // `?` here only skips this line (returns None from the closure).
        let (addr, typ, sym) = (fields.next()?, fields.next()?, fields.next()?);
        if sym != name || !allowed_types.contains(&typ) {
            return None;
        }
        u64::from_str_radix(addr, 16)
            .ok()
            .filter(|&kva| kva != 0)
    })
}
/// Derive the kernel's physical base from `/proc/iomem`.
///
/// - x86_64: start of the first "Kernel code" range minus `0x100_0000`
///   (presumably the default `CONFIG_PHYSICAL_START`, 16 MiB).
/// - aarch64: last "Kernel code" start minus the first "System RAM" start.
/// - Other architectures: always `None`. The previous body did not compile
///   there because it fell through to a `let` statement.
///
/// Returns `None` if `/proc/iomem` is unreadable or the needed ranges are
/// missing or unparsable. It also returns `None` when the kernel-code start is
/// 0, which is what unprivileged readers see because addresses are redacted.
/// Without that check a bogus wrapped value would be returned.
pub fn read_phys_base_from_iomem() -> Option<u64> {
    let iomem = std::fs::read_to_string("/proc/iomem").ok()?;
    phys_base_from_iomem_text(&iomem)
}

/// Parse the hex start address of a `/proc/iomem` line (`start-end : name`).
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn iomem_range_start(line: &str) -> Option<u64> {
    let range = line.split(':').next()?.trim();
    let start = range.split('-').next()?.trim();
    u64::from_str_radix(start, 16).ok()
}

/// x86_64: physical base = first "Kernel code" start minus the default load
/// offset.
#[cfg(target_arch = "x86_64")]
fn phys_base_from_iomem_text(iomem: &str) -> Option<u64> {
    // NOTE(review): assumed to be the default CONFIG_PHYSICAL_START — confirm
    // against the guest kernel config.
    const DEFAULT_PHYSICAL_START: u64 = 0x100_0000;
    let line = iomem
        .lines()
        .map(str::trim)
        .find(|l| l.ends_with(": Kernel code"))?;
    let phys_load = iomem_range_start(line)?;
    if phys_load == 0 {
        // Redacted iomem (unprivileged reader): no real address available.
        return None;
    }
    Some(phys_load.wrapping_sub(DEFAULT_PHYSICAL_START))
}

/// aarch64: offset of the (last) "Kernel code" range from the first
/// "System RAM" range.
#[cfg(target_arch = "aarch64")]
fn phys_base_from_iomem_text(iomem: &str) -> Option<u64> {
    let mut ram_start: Option<u64> = None;
    let mut code_start: Option<u64> = None;
    for line in iomem.lines().map(str::trim) {
        if ram_start.is_none() && line.ends_with(": System RAM") {
            ram_start = Some(iomem_range_start(line)?);
        }
        if line.ends_with(": Kernel code") {
            code_start = Some(iomem_range_start(line)?);
        }
    }
    // A zero kernel-code start means the addresses were redacted.
    let code_start = code_start.filter(|&addr| addr != 0)?;
    Some(code_start.wrapping_sub(ram_start?))
}

/// Unsupported architecture: no layout knowledge, so no answer.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
fn phys_base_from_iomem_text(_iomem: &str) -> Option<u64> {
    None
}
/// Send `buf` as a `Stdout` frame; returns whether it was written.
pub fn send_stdout_chunk(buf: &[u8]) -> bool {
    let wire_type = MsgType::Stdout.wire_value();
    write_msg(wire_type, buf)
}

/// Send `buf` as a `Stderr` frame; returns whether it was written.
pub fn send_stderr_chunk(buf: &[u8]) -> bool {
    let wire_type = MsgType::Stderr.wire_value();
    write_msg(wire_type, buf)
}

/// Send `buf` as a `SchedLog` frame. Best-effort.
pub fn send_sched_log(buf: &[u8]) {
    let _ = write_msg(MsgType::SchedLog.wire_value(), buf);
}
/// Send a `Lifecycle` frame: one phase byte followed by the raw UTF-8 bytes of
/// `reason` (no terminator). Best-effort.
pub fn send_lifecycle(phase: LifecyclePhase, reason: &str) {
    let frame: Vec<u8> = std::iter::once(phase.wire_value())
        .chain(reason.bytes())
        .collect();
    let _ = write_msg(MsgType::Lifecycle.wire_value(), &frame);
}
/// Send `code` (little-endian `i32`) as an `ExecExit` frame. Best-effort.
pub fn send_exec_exit(code: i32) {
    let encoded = code.to_le_bytes();
    let _ = write_msg(MsgType::ExecExit.wire_value(), &encoded);
}

/// Send `buf` as a `Dmesg` frame. Best-effort.
pub fn send_dmesg(buf: &[u8]) {
    let _ = write_msg(MsgType::Dmesg.wire_value(), buf);
}

/// Send `buf` as a `ProbeOutput` frame. Best-effort.
// NOTE(review): `dead_code` is allowed presumably because no caller exists in
// some build configurations — confirm.
#[allow(dead_code)]
pub fn send_probe_output(buf: &[u8]) {
    let _ = write_msg(MsgType::ProbeOutput.wire_value(), buf);
}
/// Source of snapshot `request_id`s. Starts at 1; `request_snapshot` skips 0
/// when the counter wraps.
static SNAPSHOT_REQUEST_COUNTER: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(1);
/// Held for the whole of each `request_snapshot` and `request_kernel_op`
/// call, so only one request/reply exchange is in flight at a time.
static SNAPSHOT_REQUEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Source of kernel-op `request_id`s. It is independent of the snapshot
/// counter, because replies are first filtered by message type.
static KERNEL_OP_REQUEST_COUNTER: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(1);
/// Number of initial `ppoll` waits in `bounded_read_exact` that use the fast
/// interval.
const SNAPSHOT_FAST_POLL_ITERS: u32 = 4;
/// Wait slice for the first `SNAPSHOT_FAST_POLL_ITERS` polls (presumably so
/// quick replies are noticed with low latency).
const SNAPSHOT_FAST_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_micros(100);
/// Wait slice once the fast polls are exhausted.
const SNAPSHOT_SLOW_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5);
/// Fill `buf` completely from `f`, giving up at `deadline`.
///
/// Waits for readability with `ppoll` in short slices, then does one `read`
/// for whatever is still missing. Each slice is capped at the time remaining:
/// `SNAPSHOT_FAST_POLL_INTERVAL` for the first `SNAPSHOT_FAST_POLL_ITERS`
/// polls, `SNAPSHOT_SLOW_POLL_INTERVAL` afterwards.
///
/// # Errors
/// - `TimedOut` if `deadline` passes before `buf` is full. Bytes already read
///   stay consumed from `f`, so the stream may be left mid-frame.
/// - `UnexpectedEof` if a `read` returns 0.
/// - Any other `ppoll`/`read` error. `EINTR` from either is retried.
///
/// NOTE(review): the error texts say "snapshot reply" even when this is
/// reached via `request_kernel_op`.
fn bounded_read_exact(
    f: &mut std::fs::File,
    buf: &mut [u8],
    deadline: std::time::Instant,
) -> std::io::Result<()> {
    use std::io::Read;
    use std::os::unix::io::AsRawFd;
    let fd = f.as_raw_fd();
    let mut filled = 0usize;
    let mut iter: u32 = 0;
    while filled < buf.len() {
        let now = std::time::Instant::now();
        if now >= deadline {
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "snapshot reply deadline elapsed after reading {filled} of {} header/payload bytes",
                    buf.len()
                ),
            ));
        }
        let remaining = deadline - now;
        let interval = if iter < SNAPSHOT_FAST_POLL_ITERS {
            SNAPSHOT_FAST_POLL_INTERVAL
        } else {
            SNAPSHOT_SLOW_POLL_INTERVAL
        };
        // Never sleep past the deadline.
        let slice = remaining.min(interval);
        let ts = libc::timespec {
            tv_sec: slice.as_secs() as libc::time_t,
            tv_nsec: slice.subsec_nanos() as libc::c_long,
        };
        let mut pfd = libc::pollfd {
            fd,
            events: libc::POLLIN,
            revents: 0,
        };
        // SAFETY: `pfd` is a valid, exclusively borrowed pollfd and nfds is 1.
        // `ts` lives across the call. A null sigmask leaves the signal mask
        // unchanged.
        let pr = unsafe { libc::ppoll(&mut pfd, 1, &ts, std::ptr::null()) };
        iter = iter.saturating_add(1);
        if pr < 0 {
            let err = std::io::Error::last_os_error();
            if err.kind() == std::io::ErrorKind::Interrupted {
                continue;
            }
            return Err(err);
        }
        if pr == 0 {
            // Slice expired with nothing readable; re-check the deadline.
            continue;
        }
        match f.read(&mut buf[filled..]) {
            Ok(0) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    format!(
                        "snapshot reply read returned 0 after {filled} of {} bytes",
                        buf.len()
                    ),
                ));
            }
            Ok(n) => {
                filled += n;
            }
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(())
}
/// Read one TLV frame (`ShmMessage` header + payload) from `f` before
/// `deadline`.
///
/// The header's length field is checked against `max_payload_size` before the
/// payload buffer is allocated. The payload's CRC32 must match the header.
/// Returns `(msg_type, payload)`.
///
/// # Errors
/// Propagates `bounded_read_exact` errors (including `TimedOut`). Returns
/// `InvalidData` for an unparsable header, an over-limit length, or a CRC
/// mismatch.
fn read_bulk_port_frame(
    f: &mut std::fs::File,
    max_payload_size: usize,
    deadline: std::time::Instant,
) -> std::io::Result<(u32, Vec<u8>)> {
    let invalid = |text: String| std::io::Error::new(std::io::ErrorKind::InvalidData, text);

    let mut raw_header = [0u8; std::mem::size_of::<ShmMessage>()];
    bounded_read_exact(f, &mut raw_header, deadline)?;
    let header = match ShmMessage::read_from_bytes(&raw_header) {
        Ok(parsed) => parsed,
        Err(_) => {
            return Err(invalid(
                "ShmMessage::read_from_bytes failed (header underflow)".to_string(),
            ));
        }
    };

    let length = header.length as usize;
    if length > max_payload_size {
        return Err(invalid(format!(
            "TLV length {length} exceeds max payload {max_payload_size} for port-1 RX; \
             rejecting before allocation to avoid guest OOM"
        )));
    }

    let mut body = vec![0u8; length];
    if !body.is_empty() {
        bounded_read_exact(f, &mut body, deadline)?;
    }

    let computed = crc32fast::hash(&body);
    if computed == header.crc32 {
        Ok((header.msg_type, body))
    } else {
        Err(invalid(format!(
            "TLV CRC mismatch: header crc=0x{:08x} computed=0x{computed:08x} length={length}",
            header.crc32
        )))
    }
}
/// Ask the host to take a snapshot and wait up to `timeout` for its reply.
///
/// Sends a `SnapshotRequest` frame stamped with a fresh non-zero `request_id`.
/// `tag` is truncated byte-wise to `SNAPSHOT_TAG_MAX` and zero-padded. The
/// function then reads frames from the bulk port until a `SnapshotReply` with
/// that id arrives. Frames of other types, wrongly sized or undecodable
/// replies, and replies for other ids are logged and skipped.
///
/// Returns:
/// - `Ok` when the host reports `SNAPSHOT_STATUS_OK`.
/// - `HostError` when the host reports `SNAPSHOT_STATUS_ERR`; `reason` is the
///   NUL-terminated text from the reply.
/// - `TransportError` in any of these cases: called from host context, the
///   request could not be written, the port cannot be opened or read, the
///   status is unknown, or the deadline passes.
///
/// Calls are serialized by `SNAPSHOT_REQUEST_LOCK`. NOTE: the `BULK_PORT_FD`
/// mutex is held for the whole wait, so other bulk-port writers block until
/// this call returns.
pub fn request_snapshot(
    kind: u32,
    tag: &str,
    timeout: std::time::Duration,
) -> SnapshotRequestResult {
    if !is_guest() {
        return SnapshotRequestResult::TransportError {
            reason: "request_snapshot called from host context (virtio-console port 1 \
                     is reachable only from inside the guest)"
                .into(),
        };
    }
    let _guard = SNAPSHOT_REQUEST_LOCK.lock_unpoisoned();
    // Never hand out id 0 (presumably reserved as "no request"); skip it once
    // if the counter wraps.
    let mut request_id = SNAPSHOT_REQUEST_COUNTER.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
    if request_id == 0 {
        request_id = SNAPSHOT_REQUEST_COUNTER.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
    }
    let tag_bytes = tag.as_bytes();
    let tag_len = tag_bytes.len().min(SNAPSHOT_TAG_MAX);
    let mut tag_buf = [0u8; SNAPSHOT_TAG_MAX];
    tag_buf[..tag_len].copy_from_slice(&tag_bytes[..tag_len]);
    let payload = SnapshotRequestPayload {
        request_id,
        kind,
        tag: tag_buf,
    };
    if !write_msg(MsgType::SnapshotRequest.wire_value(), payload.as_bytes()) {
        // The request never reached the host, so no reply can arrive. Fail now
        // instead of waiting out the full timeout with a misleading message.
        return SnapshotRequestResult::TransportError {
            reason: format!(
                "failed to write snapshot request to /dev/vport0p1 \
                 (request_id={request_id}, kind={kind})"
            ),
        };
    }
    let read_slot = BULK_PORT_FD.get_or_init(|| std::sync::Mutex::new(None));
    let mut read_guard = read_slot.lock_unpoisoned();
    if read_guard.is_none() {
        match try_open_bulk_port() {
            Some(f) => *read_guard = Some(f),
            None => {
                return SnapshotRequestResult::TransportError {
                    reason: "/dev/vport0p1 not yet open \
                             (multiport handshake still in flight)"
                        .into(),
                };
            }
        }
    }
    let f = read_guard
        .as_mut()
        .expect("bulk port handle just installed");
    let deadline = std::time::Instant::now() + timeout;
    loop {
        let now = std::time::Instant::now();
        if now >= deadline {
            return SnapshotRequestResult::TransportError {
                reason: format!(
                    "host did not deliver matching snapshot reply within {timeout:?} \
                     (request_id={request_id}, kind={kind})"
                ),
            };
        }
        let frame =
            match read_bulk_port_frame(f, std::mem::size_of::<SnapshotReplyPayload>(), deadline) {
                Ok(frame) => frame,
                Err(e) if e.kind() == std::io::ErrorKind::TimedOut => {
                    return SnapshotRequestResult::TransportError {
                        reason: format!(
                            "snapshot reply deadline elapsed before frame complete \
                             (request_id={request_id}, kind={kind}): {e}"
                        ),
                    };
                }
                Err(e) => {
                    // Hard read error: drop the handle so the next user reopens it.
                    *read_guard = None;
                    return SnapshotRequestResult::TransportError {
                        reason: format!(
                            "snapshot reply read failed (request_id={request_id}): {e}"
                        ),
                    };
                }
            };
        let (msg_type, frame_payload) = frame;
        if msg_type != MSG_TYPE_SNAPSHOT_REPLY {
            tracing::warn!(
                msg_type,
                len = frame_payload.len(),
                request_id,
                "request_snapshot: ignoring unexpected TLV on port 1 RX (only \
                 SnapshotReply is expected on this transport in current protocol)"
            );
            continue;
        }
        if frame_payload.len() != std::mem::size_of::<SnapshotReplyPayload>() {
            tracing::warn!(
                request_id,
                got = frame_payload.len(),
                want = std::mem::size_of::<SnapshotReplyPayload>(),
                "request_snapshot: malformed reply payload size; ignoring"
            );
            continue;
        }
        let reply = match SnapshotReplyPayload::read_from_bytes(&frame_payload) {
            Ok(r) => r,
            Err(_) => {
                tracing::warn!(
                    request_id,
                    "request_snapshot: SnapshotReplyPayload::read_from_bytes failed; ignoring"
                );
                continue;
            }
        };
        if reply.request_id != request_id {
            tracing::warn!(
                expected = request_id,
                got = reply.request_id,
                "request_snapshot: stale reply id (likely a leftover from a prior \
                 request that timed out on the guest side); ignoring"
            );
            continue;
        }
        return match reply.status {
            SNAPSHOT_STATUS_OK => SnapshotRequestResult::Ok,
            SNAPSHOT_STATUS_ERR => {
                // The reason field is NUL-terminated, or fills the whole buffer.
                let len = reply
                    .reason
                    .iter()
                    .position(|&b| b == 0)
                    .unwrap_or(SNAPSHOT_REASON_MAX);
                let reason = String::from_utf8_lossy(&reply.reason[..len]).to_string();
                SnapshotRequestResult::HostError { reason }
            }
            other => SnapshotRequestResult::TransportError {
                reason: format!(
                    "host reply with unknown status {other} \
                     (expected OK={SNAPSHOT_STATUS_OK} or ERR={SNAPSHOT_STATUS_ERR})"
                ),
            },
        };
    }
}
/// Send a kernel-op request to the host and wait up to `timeout` for the
/// matching reply.
///
/// The caller's `request.request_id` is replaced with a fresh non-zero id.
/// The request is postcard-encoded and sent as a `KernelOpRequest` frame. The
/// function then reads frames from the bulk port (payload capped at
/// `KERNEL_OP_REPLY_MAX`) until a `KernelOpReply` with that id decodes. Frames
/// of other types, undecodable replies, and replies for other ids are logged
/// and skipped.
///
/// Returns `Ok(reply)` on a match. Returns `TransportError` in any of these
/// cases: called from host context, encoding fails, the request could not be
/// written, the port cannot be opened or read, or the deadline passes.
///
/// Serialized with `request_snapshot` via `SNAPSHOT_REQUEST_LOCK`. NOTE: the
/// `BULK_PORT_FD` mutex is held for the whole wait, blocking other bulk-port
/// writers.
pub fn request_kernel_op(
    request: KernelOpRequestPayload,
    timeout: std::time::Duration,
) -> KernelOpRequestResult {
    if !is_guest() {
        return KernelOpRequestResult::TransportError {
            reason: "request_kernel_op called from host context (virtio-console port 1 \
                     is reachable only from inside the guest)"
                .into(),
        };
    }
    let _guard = SNAPSHOT_REQUEST_LOCK.lock_unpoisoned();
    // Never hand out id 0; skip it once if the counter wraps.
    let mut request_id =
        KERNEL_OP_REQUEST_COUNTER.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
    if request_id == 0 {
        request_id = KERNEL_OP_REQUEST_COUNTER.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
    }
    let stamped = KernelOpRequestPayload {
        request_id,
        ..request
    };
    let payload_bytes = match postcard::to_allocvec(&stamped) {
        Ok(b) => b,
        Err(e) => {
            return KernelOpRequestResult::TransportError {
                reason: format!(
                    "request_kernel_op: postcard encode failed (request_id={request_id}): {e}"
                ),
            };
        }
    };
    if !write_msg(MsgType::KernelOpRequest.wire_value(), &payload_bytes) {
        // The request never reached the host, so no reply can arrive. Fail now
        // instead of waiting out the full timeout with a misleading message.
        return KernelOpRequestResult::TransportError {
            reason: format!(
                "failed to write kernel-op request to /dev/vport0p1 \
                 (request_id={request_id})"
            ),
        };
    }
    let read_slot = BULK_PORT_FD.get_or_init(|| std::sync::Mutex::new(None));
    let mut read_guard = read_slot.lock_unpoisoned();
    if read_guard.is_none() {
        match try_open_bulk_port() {
            Some(f) => *read_guard = Some(f),
            None => {
                return KernelOpRequestResult::TransportError {
                    reason: "/dev/vport0p1 not yet open \
                             (multiport handshake still in flight)"
                        .into(),
                };
            }
        }
    }
    let f = read_guard
        .as_mut()
        .expect("bulk port handle just installed");
    let deadline = std::time::Instant::now() + timeout;
    loop {
        let now = std::time::Instant::now();
        if now >= deadline {
            return KernelOpRequestResult::TransportError {
                reason: format!(
                    "host did not deliver matching kernel-op reply within {timeout:?} \
                     (request_id={request_id})"
                ),
            };
        }
        let frame = match read_bulk_port_frame(f, KERNEL_OP_REPLY_MAX, deadline) {
            Ok(frame) => frame,
            Err(e) if e.kind() == std::io::ErrorKind::TimedOut => {
                return KernelOpRequestResult::TransportError {
                    reason: format!(
                        "kernel-op reply deadline elapsed before frame complete \
                         (request_id={request_id}): {e}"
                    ),
                };
            }
            Err(e) => {
                // Hard read error: drop the handle so the next user reopens it.
                *read_guard = None;
                return KernelOpRequestResult::TransportError {
                    reason: format!("kernel-op reply read failed (request_id={request_id}): {e}"),
                };
            }
        };
        let (msg_type, frame_payload) = frame;
        if msg_type != MSG_TYPE_KERNEL_OP_REPLY {
            tracing::warn!(
                msg_type,
                len = frame_payload.len(),
                request_id,
                "request_kernel_op: ignoring non-KernelOpReply TLV on port 1 RX (likely a \
                 stale snapshot reply from a prior request that timed out on the guest side)"
            );
            continue;
        }
        let reply: KernelOpReplyPayload = match postcard::from_bytes(&frame_payload) {
            Ok(r) => r,
            Err(e) => {
                tracing::warn!(
                    request_id,
                    error = %e,
                    "request_kernel_op: postcard decode failed; ignoring"
                );
                continue;
            }
        };
        if reply.request_id != request_id {
            tracing::warn!(
                expected = request_id,
                got = reply.request_id,
                "request_kernel_op: stale reply id (likely a leftover from a prior \
                 request that timed out on the guest side); ignoring"
            );
            continue;
        }
        return KernelOpRequestResult::Ok(reply);
    }
}
// Unit tests for the guest transport:
// - every send_*/request_* entry point must be a no-op (or return an error)
//   when `is_guest()` is forced to false;
// - `read_bulk_port_frame` framing limits, exercised over an anonymous pipe;
// - the `is_guest` test-override guard;
// - size-based branch selection in `classify_test_result`.
#[cfg(test)]
mod tests {
    use super::*;
    /// Run `f` and assert it never reached `write_to_bulk_port`, by comparing
    /// `BULK_PORT_WRITE_ATTEMPTS` before and after.
    fn assert_no_bulk_write(label: &str, f: impl FnOnce()) {
        use std::sync::atomic::Ordering;
        let before = BULK_PORT_WRITE_ATTEMPTS.load(Ordering::SeqCst);
        f();
        let after = BULK_PORT_WRITE_ATTEMPTS.load(Ordering::SeqCst);
        assert_eq!(
            after, before,
            "{label}: host-context call must NOT reach write_to_bulk_port; \
             the is_guest() gate failed to suppress the write \
             (before={before}, after={after})",
        );
    }
    #[test]
    fn send_exit_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_exit(0)", || send_exit(0));
        assert_no_bulk_write("send_exit(-1)", || send_exit(-1));
    }
    #[test]
    fn send_test_result_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_test_result", || {
            send_test_result(&crate::assert::AssertResult::pass())
        });
    }
    #[test]
    fn send_payload_metrics_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        let pm = crate::test_support::PayloadMetrics {
            payload_index: 0,
            metrics: vec![],
            exit_code: 0,
        };
        // The bool return is the observable no-op signal here.
        assert!(
            !send_payload_metrics(&pm),
            "host-context send must return false (no frame written)"
        );
    }
    #[test]
    fn send_profraw_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_profraw", || send_profraw(b"\x01\x02\x03"));
    }
    #[test]
    fn send_stimulus_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_stimulus", || send_stimulus(&[0u8; 24]));
    }
    #[test]
    fn send_raw_payload_output_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        let raw = crate::test_support::RawPayloadOutput {
            payload_index: 0,
            stdout: String::new(),
            stderr: String::new(),
            hint: None,
            metric_hints: vec![],
            metric_bounds: None,
        };
        assert_no_bulk_write("send_raw_payload_output", || send_raw_payload_output(&raw));
    }
    #[test]
    fn send_sched_exit_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_sched_exit(0)", || send_sched_exit(0));
        assert_no_bulk_write("send_sched_exit(-1)", || send_sched_exit(-1));
    }
    #[test]
    fn send_scenario_start_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_scenario_start", send_scenario_start);
    }
    #[test]
    fn send_scenario_end_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_scenario_end(0,0)", || send_scenario_end(0, 0));
        assert_no_bulk_write("send_scenario_end(MAX,MAX)", || {
            send_scenario_end(u64::MAX, u64::MAX)
        });
    }
    #[test]
    fn send_sys_rdy_from_host_context_returns_false() {
        let _g = IsGuestOverrideGuard::new(false);
        assert!(
            !send_sys_rdy(),
            "host-context call must return false so the guest's \
             retry loop can distinguish 'wrote' from 'noop'"
        );
    }
    #[test]
    fn send_stdout_chunk_from_host_context_returns_false() {
        let _g = IsGuestOverrideGuard::new(false);
        assert!(!send_stdout_chunk(b"hello"));
    }
    #[test]
    fn send_stderr_chunk_from_host_context_returns_false() {
        let _g = IsGuestOverrideGuard::new(false);
        assert!(!send_stderr_chunk(b"oops"));
    }
    #[test]
    fn send_sched_log_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_sched_log", || {
            send_sched_log(b"---SCHED_OUTPUT_START---\n")
        });
    }
    #[test]
    fn send_lifecycle_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_lifecycle(InitStarted)", || {
            send_lifecycle(LifecyclePhase::InitStarted, "")
        });
        assert_no_bulk_write("send_lifecycle(PayloadStarting)", || {
            send_lifecycle(LifecyclePhase::PayloadStarting, "")
        });
        assert_no_bulk_write("send_lifecycle(SchedulerDied)", || {
            send_lifecycle(LifecyclePhase::SchedulerDied, "")
        });
        assert_no_bulk_write("send_lifecycle(SchedulerNotAttached)", || {
            send_lifecycle(LifecyclePhase::SchedulerNotAttached, "verifier rejected")
        });
    }
    #[test]
    fn send_exec_exit_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_exec_exit(0)", || send_exec_exit(0));
        assert_no_bulk_write("send_exec_exit(-1)", || send_exec_exit(-1));
    }
    #[test]
    fn send_dmesg_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_dmesg", || {
            send_dmesg(b"[ 0.000000] Linux version 6.16.0\n")
        });
    }
    #[test]
    fn send_probe_output_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        assert_no_bulk_write("send_probe_output", || send_probe_output(b"{}\n"));
    }
    #[test]
    fn request_snapshot_from_host_context_returns_transport_error() {
        let _g = IsGuestOverrideGuard::new(false);
        let r = request_snapshot(0, "tag", std::time::Duration::from_millis(0));
        match r {
            SnapshotRequestResult::TransportError { .. } => {}
            other => panic!("expected TransportError from host context, got {other:?}"),
        }
    }
    #[test]
    fn request_kernel_op_from_host_context_returns_transport_error() {
        let _g = IsGuestOverrideGuard::new(false);
        let request = crate::vmm::wire::KernelOpRequestPayload {
            request_id: 0,
            mode: crate::vmm::wire::KernelOpMode::Hot,
            direction: crate::vmm::wire::KernelOpDirection::Write,
            tag: String::new(),
            entries: vec![],
        };
        let r = request_kernel_op(request, std::time::Duration::from_millis(0));
        match r {
            crate::vmm::wire::KernelOpRequestResult::TransportError { .. } => {}
            other => panic!("expected TransportError from host context, got {other:?}"),
        }
    }
    /// A forged header whose length exceeds the caller's cap must be rejected
    /// with `InvalidData`, and the error text must cite that cap.
    #[test]
    fn read_bulk_port_frame_respects_caller_supplied_cap() {
        use std::os::unix::io::FromRawFd;
        let mut fds = [0i32; 2];
        let r = unsafe { libc::pipe(fds.as_mut_ptr()) };
        assert_eq!(r, 0, "pipe(2) failed: {}", std::io::Error::last_os_error());
        // Each File takes ownership of one pipe end and closes it on drop.
        let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
        let mut write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
        let header = ShmMessage {
            msg_type: MSG_TYPE_KERNEL_OP_REPLY,
            length: 200,
            crc32: 0,
            _pad: 0,
        };
        use std::io::Write;
        write_end
            .write_all(header.as_bytes())
            .expect("write forged header");
        drop(write_end);
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        let err = read_bulk_port_frame(&mut read_end, 100, deadline)
            .expect_err("cap=100 must reject length=200");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds max payload 100"),
            "error must cite the caller-supplied cap, got: {msg}"
        );
    }
    /// A `u32::MAX` length must be rejected before any payload allocation.
    #[test]
    fn read_bulk_port_frame_rejects_oversized_length_before_alloc() {
        use std::os::unix::io::FromRawFd;
        let mut fds = [0i32; 2];
        let r = unsafe { libc::pipe(fds.as_mut_ptr()) };
        assert_eq!(r, 0, "pipe(2) failed: {}", std::io::Error::last_os_error());
        let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
        let mut write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
        let header = ShmMessage {
            msg_type: MSG_TYPE_SNAPSHOT_REPLY,
            length: u32::MAX,
            crc32: 0,
            _pad: 0,
        };
        use std::io::Write;
        write_end
            .write_all(header.as_bytes())
            .expect("write forged header");
        drop(write_end);
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        let err = read_bulk_port_frame(
            &mut read_end,
            std::mem::size_of::<SnapshotReplyPayload>(),
            deadline,
        )
        .expect_err("oversized length must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds max payload"),
            "error must explain the cap, got: {msg}"
        );
    }
    /// Boundary case: a payload of exactly the cap (with a valid CRC) is accepted.
    #[test]
    fn read_bulk_port_frame_accepts_exact_max_payload() {
        use std::os::unix::io::FromRawFd;
        let mut fds = [0i32; 2];
        let r = unsafe { libc::pipe(fds.as_mut_ptr()) };
        assert_eq!(r, 0, "pipe(2) failed: {}", std::io::Error::last_os_error());
        let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
        let mut write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
        let payload = vec![0u8; std::mem::size_of::<SnapshotReplyPayload>()];
        let header = ShmMessage {
            msg_type: MSG_TYPE_SNAPSHOT_REPLY,
            length: payload.len() as u32,
            crc32: crc32fast::hash(&payload),
            _pad: 0,
        };
        use std::io::Write;
        write_end.write_all(header.as_bytes()).expect("header");
        write_end.write_all(&payload).expect("payload");
        drop(write_end);
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        let (msg_type, body) = read_bulk_port_frame(
            &mut read_end,
            std::mem::size_of::<SnapshotReplyPayload>(),
            deadline,
        )
        .expect("exact-size payload must succeed");
        assert_eq!(msg_type, MSG_TYPE_SNAPSHOT_REPLY);
        assert_eq!(body.len(), std::mem::size_of::<SnapshotReplyPayload>());
    }
    #[test]
    fn is_guest_override_round_trips_through_thread_local() {
        {
            let _g = IsGuestOverrideGuard::new(false);
            assert!(!is_guest());
        }
        {
            let _g = IsGuestOverrideGuard::new(true);
            assert!(is_guest());
        }
    }
    #[test]
    fn is_guest_override_guards_nest_correctly() {
        let _outer = IsGuestOverrideGuard::new(true);
        assert!(is_guest());
        {
            let _inner = IsGuestOverrideGuard::new(false);
            assert!(!is_guest());
        }
        // Dropping the inner guard restores the outer override.
        assert!(is_guest());
    }
    /// Covers each `TestResultWire` branch by shrinking `max` around the
    /// encoded size of a result carrying `n` wake-latency samples.
    #[test]
    fn classify_test_result_selects_branch_by_size() {
        use crate::assert::{AssertResult, PhaseBucket, PhaseCgroupStats};
        // Build a passing result with one phase and `n` raw samples in one cgroup.
        let mk = |n: u64| {
            let mut pc = std::collections::BTreeMap::new();
            pc.insert(
                "cg".to_string(),
                PhaseCgroupStats {
                    wake_latencies_ns: (0..n).collect(),
                    wake_sample_total: n,
                    ..Default::default()
                },
            );
            let mut r = AssertResult::pass();
            r.stats.phases = vec![PhaseBucket {
                step_index: 1,
                label: "Step[0]".to_string(),
                start_ms: 0,
                end_ms: 1,
                sample_count: 0,
                metrics: std::collections::BTreeMap::new(),
                per_cgroup: pc,
            }];
            r
        };
        let r = mk(1000);
        let full = postcard::to_stdvec(&r).unwrap().len();
        // Exactly fits: Raw.
        match classify_test_result(&r, full) {
            Some(TestResultWire::Raw(b)) => assert_eq!(b.len(), full),
            other => panic!("expected Raw, got {other:?}"),
        }
        // One byte short: samples stripped, verdict kept.
        let max = full - 1;
        match classify_test_result(&r, max) {
            Some(TestResultWire::Stripped { bytes, dropped }) => {
                assert_eq!(dropped, 1000);
                assert!(bytes.len() <= max);
                let decoded: AssertResult = postcard::from_bytes(&bytes).unwrap();
                assert!(decoded.is_pass(), "verdict PRESERVED — no PASS->FAIL flip");
                assert!(
                    decoded.stats.phases[0].per_cgroup["cg"]
                        .wake_latencies_ns
                        .is_empty(),
                    "only the samples were dropped",
                );
            }
            other => panic!("expected Stripped, got {other:?}"),
        }
        // Nothing can fit: Truncated reports the post-strip size.
        match classify_test_result(&r, 1) {
            Some(TestResultWire::Truncated { offending }) => {
                assert!(offending > 1, "offending overran max=1");
                assert!(
                    offending < full,
                    "offending {offending} is the post-strip size, not the pre-strip original {full}",
                );
            }
            other => panic!("expected Truncated, got {other:?}"),
        }
        // No samples to strip: Truncated reports the original size.
        let ns = mk(0);
        let ns_full = postcard::to_stdvec(&ns).unwrap().len();
        match classify_test_result(&ns, ns_full - 1) {
            Some(TestResultWire::Truncated { offending }) => assert_eq!(offending, ns_full),
            other => panic!("expected Truncated, got {other:?}"),
        }
    }
}