use crate::vmm::wire::{
LifecyclePhase, MSG_TYPE_SNAPSHOT_REPLY, MsgType, SNAPSHOT_REASON_MAX, SNAPSHOT_STATUS_ERR,
SNAPSHOT_STATUS_OK, SNAPSHOT_TAG_MAX, ShmMessage, SnapshotReplyPayload, SnapshotRequestPayload,
SnapshotRequestResult,
};
use zerocopy::{FromBytes, IntoBytes};
/// Process-wide lock that serializes callers of `write_msg`.
///
/// `write_msg` holds it for the entire frame write (header + payload), so two
/// threads emitting frames at once are handled strictly one after the other.
/// NOTE(review): this is `pub`, so code outside this file may also take it —
/// check those callers before changing its scope.
pub static GUEST_WRITE_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Reports whether this process is running inside a ktstr guest VM.
///
/// Detection looks for the exact whitespace-separated token `KTSTR_GUEST=1`
/// on the kernel command line (`/proc/cmdline`). An unreadable cmdline counts
/// as "not a guest". The answer is computed once and cached for the lifetime
/// of the process.
///
/// In test builds, a per-thread override installed through
/// [`IsGuestOverrideGuard`] takes precedence over the cached answer.
pub fn is_guest() -> bool {
    #[cfg(test)]
    {
        let forced = IS_GUEST_TEST_OVERRIDE.with(|cell| cell.get());
        if let Some(answer) = forced {
            return answer;
        }
    }
    static IS_GUEST: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    let cached = IS_GUEST.get_or_init(|| match std::fs::read_to_string("/proc/cmdline") {
        Ok(cmdline) => cmdline
            .split_whitespace()
            .any(|tok| tok == "KTSTR_GUEST=1"),
        Err(_) => false,
    });
    *cached
}

#[cfg(test)]
thread_local! {
    /// Per-thread answer forced onto `is_guest` in test builds; `None` defers
    /// to the real `/proc/cmdline` probe.
    static IS_GUEST_TEST_OVERRIDE: std::cell::Cell<Option<bool>> =
        const { std::cell::Cell::new(None) };
}

/// Test-only RAII guard that pins `is_guest()` to a fixed answer on the
/// current thread until it is dropped.
///
/// Guards nest: dropping one restores whatever override (or lack of one) was
/// active when it was created.
#[cfg(test)]
pub(crate) struct IsGuestOverrideGuard {
    prev: Option<bool>,
}

#[cfg(test)]
impl IsGuestOverrideGuard {
    /// Installs `value` as this thread's override, remembering the previous one.
    pub(crate) fn new(value: bool) -> Self {
        let prev = IS_GUEST_TEST_OVERRIDE.with(|cell| cell.replace(Some(value)));
        IsGuestOverrideGuard { prev }
    }
}

#[cfg(test)]
impl Drop for IsGuestOverrideGuard {
    fn drop(&mut self) {
        let restore = self.prev;
        IS_GUEST_TEST_OVERRIDE.with(move |cell| cell.set(restore));
    }
}
/// Returns `true` when running inside the guest.
///
/// Otherwise it logs a warning naming `fn_name` and the `msg_type` the caller
/// tried to send, then returns `false` so the caller can skip the send.
fn assert_guest_context(fn_name: &str, msg_type: u32) -> bool {
    let in_guest = is_guest();
    if !in_guest {
        tracing::warn!(
            msg_type = msg_type,
            "guest_comms::{fn_name} called from host context"
        );
    }
    in_guest
}
/// Lazily opened, cached handle to the virtio-console bulk port
/// (`/dev/vport0p1`), shared by the frame writer and the snapshot-reply reader.
///
/// `None` means the port has not been opened yet, or its handle was dropped
/// after an I/O error. `write_to_bulk_port` and `request_snapshot` both (re)open
/// it on demand via `try_open_bulk_port`, and reset it to `None` on failure.
static BULK_PORT_FD: std::sync::OnceLock<std::sync::Mutex<Option<std::fs::File>>> =
    std::sync::OnceLock::new();
/// Opens `/dev/vport0p1` for reading and writing.
///
/// Returns `None` on any error (e.g. the device node does not exist yet). The
/// caller decides whether to retry later.
fn try_open_bulk_port() -> Option<std::fs::File> {
    let mut options = std::fs::OpenOptions::new();
    options.read(true).write(true);
    match options.open("/dev/vport0p1") {
        Ok(port) => Some(port),
        Err(_) => None,
    }
}
/// Sends one framed message of type `msg_type` carrying `payload` to the host.
///
/// Returns `false` without touching the port when called outside the guest;
/// a warning is logged in that case. Otherwise it holds `GUEST_WRITE_LOCK`
/// for the whole write and returns whether the complete frame was written.
fn write_msg(msg_type: u32, payload: &[u8]) -> bool {
    if assert_guest_context("write_msg", msg_type) {
        // The lock guards no data, so a poisoned lock (another writer
        // panicked) is safe to recover from.
        let _serialized = GUEST_WRITE_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        write_to_bulk_port(msg_type, payload)
    } else {
        false
    }
}
/// Writes one frame to the bulk port: a `ShmMessage` header followed by `payload`.
///
/// The header carries `msg_type`, the payload length and a CRC32 of the payload.
/// The port handle is opened and cached in `BULK_PORT_FD` on first use. Header
/// and payload go out through a single gather write (`writev`), repeated until
/// every byte has been written. `EINTR` is retried.
///
/// Returns `true` once the whole frame has been written. Returns `false`:
/// - without writing, when the port cannot be opened or the payload length
///   does not fit in a `u32`;
/// - after dropping the cached handle (so the next call reopens the port),
///   when `writev` fails or reports that zero bytes were written.
fn write_to_bulk_port(msg_type: u32, payload: &[u8]) -> bool {
    let slot = BULK_PORT_FD.get_or_init(|| std::sync::Mutex::new(None));
    // A poisoned lock only means another holder panicked; the Option<File>
    // is still usable, so recover the guard rather than propagate the panic.
    let mut guard = slot.lock().unwrap_or_else(|e| e.into_inner());
    if guard.is_none() {
        match try_open_bulk_port() {
            Some(f) => *guard = Some(f),
            None => return false,
        }
    }
    let f = guard.as_mut().expect("bulk port handle just installed");
    // The wire header stores the length as u32; larger payloads cannot be framed.
    let Ok(length_u32) = u32::try_from(payload.len()) else {
        tracing::warn!(
            len = payload.len(),
            msg_type,
            "write_to_bulk_port: payload exceeds u32::MAX; dropping"
        );
        return false;
    };
    let msg = ShmMessage {
        msg_type,
        length: length_u32,
        crc32: crc32fast::hash(payload),
        _pad: 0,
    };
    let header_bytes = msg.as_bytes();
    // `total` and `written` only feed the zero-write diagnostic and the
    // final debug assertion.
    let total = header_bytes.len() + payload.len();
    let fd = std::os::unix::io::AsRawFd::as_raw_fd(f);
    let mut iovs = [
        std::io::IoSlice::new(header_bytes),
        std::io::IoSlice::new(payload),
    ];
    let mut bufs: &mut [std::io::IoSlice<'_>] = &mut iovs[..];
    let mut written: usize = 0;
    // After a short write, `advance_slices` drops fully sent slices and trims
    // the partially sent one, so each retry resumes at the first unsent byte.
    while !bufs.is_empty() {
        // SAFETY: `std::io::IoSlice` is documented to be ABI-compatible with
        // `struct iovec` on Unix, so the pointer cast is valid. `bufs` holds
        // `bufs.len()` (at most 2) initialized entries, which borrow `msg` and
        // `payload`; both outlive this call. `fd` belongs to the `File` kept
        // alive by `guard`.
        let r = unsafe {
            libc::writev(
                fd,
                bufs.as_ptr() as *const libc::iovec,
                bufs.len() as libc::c_int,
            )
        };
        if r < 0 {
            let err = std::io::Error::last_os_error();
            if err.kind() == std::io::ErrorKind::Interrupted {
                continue;
            }
            tracing::warn!(
                %err,
                msg_type,
                len = payload.len(),
                "write_to_bulk_port: writev failed"
            );
            // Drop the handle so the next write reopens the port.
            *guard = None;
            return false;
        }
        // Zero progress with bytes still pending would otherwise spin forever.
        if r == 0 {
            tracing::warn!(
                msg_type,
                len = payload.len(),
                written,
                total,
                "write_to_bulk_port: writev returned 0"
            );
            *guard = None;
            return false;
        }
        let n = r as usize;
        written += n;
        std::io::IoSlice::advance_slices(&mut bufs, n);
    }
    debug_assert_eq!(written, total);
    true
}
/// Sends an `Exit` frame whose payload is `code` as 4 little-endian bytes.
///
/// Best-effort: whether the write succeeded is not reported.
pub fn send_exit(code: i32) {
    let encoded = code.to_le_bytes();
    let _ = write_msg(MsgType::Exit.wire_value(), &encoded);
}
/// Sends the test verdict to the host as a bincode-encoded `TestResult` frame.
///
/// If the encoded verdict is larger than `MAX_BULK_FRAME_PAYLOAD`, a synthetic
/// failing verdict is sent in its place. It explains the overflow, so the host
/// still receives a (failing) result instead of nothing. Encoding failures are
/// reported on stderr. The write itself is best-effort.
pub fn send_test_result(result: &crate::assert::AssertResult) {
    let bytes = match bincode::serde::encode_to_vec(result, bincode::config::standard()) {
        Ok(bytes) => bytes,
        Err(e) => {
            eprintln!("ktstr: bincode-encode AssertResult for bulk-port emit: {e}");
            return;
        }
    };
    if bytes.len() <= crate::vmm::bulk::MAX_BULK_FRAME_PAYLOAD as usize {
        write_msg(MsgType::TestResult.wire_value(), &bytes);
        return;
    }
    tracing::error!(
        size = bytes.len(),
        max = crate::vmm::bulk::MAX_BULK_FRAME_PAYLOAD,
        "AssertResult exceeds bulk port frame limit, sending truncated verdict"
    );
    let truncated = crate::assert::AssertResult::fail(crate::assert::AssertDetail::new(
        crate::assert::DetailKind::Other,
        format!(
            "AssertResult bincode size {} exceeded bulk port limit {}; \
             original details dropped",
            bytes.len(),
            crate::vmm::bulk::MAX_BULK_FRAME_PAYLOAD,
        ),
    ));
    match bincode::serde::encode_to_vec(&truncated, bincode::config::standard()) {
        Ok(small) => {
            write_msg(MsgType::TestResult.wire_value(), &small);
        }
        // Report the failure: otherwise the host receives no verdict at all
        // and nothing records why.
        Err(e) => {
            eprintln!("ktstr: bincode-encode truncated AssertResult for bulk-port emit: {e}");
        }
    }
}
/// Sends `metrics` to the host as a bincode-encoded `PayloadMetrics` frame.
///
/// Encoding failures are reported on stderr. The write is best-effort.
pub fn send_payload_metrics(metrics: &crate::test_support::PayloadMetrics) {
    let config = bincode::config::standard();
    let encoded = match bincode::serde::encode_to_vec(metrics, config) {
        Ok(encoded) => encoded,
        Err(e) => {
            eprintln!("ktstr: bincode-encode PayloadMetrics for bulk-port emit: {e}");
            return;
        }
    };
    let _ = write_msg(MsgType::PayloadMetrics.wire_value(), &encoded);
}
/// Forwards `buf` unchanged to the host as a `Profraw` frame (best-effort).
pub fn send_profraw(buf: &[u8]) {
    let frame_type = MsgType::Profraw.wire_value();
    let _ = write_msg(frame_type, buf);
}

/// Forwards an already-encoded `payload` unchanged as a `Stimulus` frame
/// (best-effort).
pub fn send_stimulus(payload: &[u8]) {
    let frame_type = MsgType::Stimulus.wire_value();
    let _ = write_msg(frame_type, payload);
}
/// Sends `raw` to the host as a bincode-encoded `RawPayloadOutput` frame.
///
/// Encoding failures are reported on stderr. The write is best-effort.
pub(crate) fn send_raw_payload_output(raw: &crate::test_support::RawPayloadOutput) {
    let config = bincode::config::standard();
    let encoded = match bincode::serde::encode_to_vec(raw, config) {
        Ok(encoded) => encoded,
        Err(e) => {
            eprintln!("ktstr: bincode-encode RawPayloadOutput for bulk-port emit: {e}");
            return;
        }
    };
    let _ = write_msg(MsgType::RawPayloadOutput.wire_value(), &encoded);
}
/// Sends a `SchedExit` frame whose payload is `code` as 4 little-endian bytes
/// (best-effort).
pub fn send_sched_exit(code: i32) {
    let encoded = code.to_le_bytes();
    let _ = write_msg(MsgType::SchedExit.wire_value(), &encoded);
}

/// Sends an empty `ScenarioStart` frame (best-effort).
pub fn send_scenario_start() {
    let _ = write_msg(MsgType::ScenarioStart.wire_value(), &[]);
}

/// Sends a `ScenarioEnd` frame whose payload is `elapsed_ms` as 8
/// little-endian bytes (best-effort).
pub fn send_scenario_end(elapsed_ms: u64) {
    let encoded = elapsed_ms.to_le_bytes();
    let _ = write_msg(MsgType::ScenarioEnd.wire_value(), &encoded);
}

/// Sends an empty `ScenarioPause` frame (best-effort).
pub fn send_scenario_pause() {
    let _ = write_msg(MsgType::ScenarioPause.wire_value(), &[]);
}

/// Sends an empty `ScenarioResume` frame (best-effort).
pub fn send_scenario_resume() {
    let _ = write_msg(MsgType::ScenarioResume.wire_value(), &[]);
}

/// Sends an empty `SysRdy` frame.
///
/// Returns `true` only if the frame was fully written. It returns `false` from
/// host context or on a failed write, so callers can tell "wrote" apart from
/// "no-op" and retry.
pub fn send_sys_rdy() -> bool {
    let ready = MsgType::SysRdy.wire_value();
    write_msg(ready, &[])
}
/// Sends the kernel address pair to the host as a 16-byte `KERN_ADDRS` payload.
///
/// Layout (little-endian):
/// - bytes `0..8`: `phys_base + 1` (wrapping);
/// - bytes `8..16`: `page_offset_base`.
///
/// Returns whether the frame was fully written.
///
/// NOTE(review): the `+ 1` on `phys_base` presumably lets the host tell a
/// genuine phys_base of 0 apart from an all-zero/absent field. Confirm against
/// the host-side decoder before changing it.
pub fn send_kern_addrs(phys_base: u64, page_offset_base: u64) -> bool {
    let encoded_phys = phys_base.wrapping_add(1).to_le_bytes();
    let encoded_pob = page_offset_base.to_le_bytes();
    let mut payload = [0u8; 16];
    let (low, high) = payload.split_at_mut(8);
    low.copy_from_slice(&encoded_phys);
    high.copy_from_slice(&encoded_pob);
    write_msg(super::wire::MSG_TYPE_KERN_ADDRS, &payload)
}
/// Derives the kernel's physical load offset from `/proc/iomem`.
///
/// - x86_64: takes the first `Kernel code` resource and subtracts the default
///   `CONFIG_PHYSICAL_START` (16 MiB) from its start address. This assumes
///   the guest kernel was built with that default.
/// - aarch64: returns the start of the last `Kernel code` resource minus the
///   start of the first `System RAM` resource.
/// - any other architecture: always `None`.
///
/// Returns `None` if `/proc/iomem` cannot be read, if a required resource line
/// is missing, or if a matched line's start address is not valid hex.
/// Subtraction wraps rather than failing.
///
/// NOTE(review): unprivileged readers of `/proc/iomem` may see zeroed
/// addresses. Confirm this runs with sufficient privileges.
pub fn read_phys_base_from_iomem() -> Option<u64> {
    let iomem = std::fs::read_to_string("/proc/iomem").ok()?;

    /// Parses the hex start address of an iomem line such as
    /// `01000000-01e00000 : Kernel code`; `None` if it is not valid hex.
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    fn range_start(line: &str) -> Option<u64> {
        let range = line.split(':').next()?.trim();
        let start = range.split('-').next()?.trim();
        u64::from_str_radix(start, 16).ok()
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Default CONFIG_PHYSICAL_START: where kernel text begins when the
        // kernel is not relocated.
        const DEFAULT_PHYSICAL_START: u64 = 0x100_0000;
        for line in iomem.lines().map(str::trim) {
            if line.ends_with(": Kernel code") {
                return Some(range_start(line)?.wrapping_sub(DEFAULT_PHYSICAL_START));
            }
        }
        None
    }
    #[cfg(target_arch = "aarch64")]
    {
        let mut ram_start: Option<u64> = None;
        let mut code_start: Option<u64> = None;
        for line in iomem.lines().map(str::trim) {
            if ram_start.is_none() && line.ends_with(": System RAM") {
                ram_start = Some(range_start(line)?);
            }
            if line.ends_with(": Kernel code") {
                code_start = Some(range_start(line)?);
            }
        }
        Some(code_start?.wrapping_sub(ram_start?))
    }
    // Without this arm, the function has no return value on other targets
    // and fails to compile.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        let _ = iomem;
        None
    }
}
/// Forwards a chunk of captured stdout as a `Stdout` frame.
///
/// Returns whether the whole frame was written.
pub fn send_stdout_chunk(buf: &[u8]) -> bool {
    let frame_type = MsgType::Stdout.wire_value();
    write_msg(frame_type, buf)
}

/// Forwards a chunk of captured stderr as a `Stderr` frame.
///
/// Returns whether the whole frame was written.
pub fn send_stderr_chunk(buf: &[u8]) -> bool {
    let frame_type = MsgType::Stderr.wire_value();
    write_msg(frame_type, buf)
}

/// Forwards `buf` unchanged as a `SchedLog` frame (best-effort).
// NOTE(review): `dead_code` is allowed, presumably because some build
// configurations never call this; confirm before removing.
#[allow(dead_code)]
pub fn send_sched_log(buf: &[u8]) {
    let _ = write_msg(MsgType::SchedLog.wire_value(), buf);
}
/// Sends a `Lifecycle` frame (best-effort).
///
/// The payload is one byte holding `phase`'s wire value, followed by the raw
/// UTF-8 bytes of `reason`. `reason` is not NUL-terminated and may be empty.
pub fn send_lifecycle(phase: LifecyclePhase, reason: &str) {
    let frame: Vec<u8> = std::iter::once(phase.wire_value())
        .chain(reason.bytes())
        .collect();
    let _ = write_msg(MsgType::Lifecycle.wire_value(), &frame);
}

/// Sends an `ExecExit` frame whose payload is `code` as 4 little-endian bytes
/// (best-effort).
pub fn send_exec_exit(code: i32) {
    let encoded = code.to_le_bytes();
    let _ = write_msg(MsgType::ExecExit.wire_value(), &encoded);
}

/// Forwards `buf` unchanged as a `Dmesg` frame (best-effort).
#[allow(dead_code)]
pub fn send_dmesg(buf: &[u8]) {
    let _ = write_msg(MsgType::Dmesg.wire_value(), buf);
}

/// Forwards `buf` unchanged as a `ProbeOutput` frame (best-effort).
#[allow(dead_code)]
pub fn send_probe_output(buf: &[u8]) {
    let _ = write_msg(MsgType::ProbeOutput.wire_value(), buf);
}
/// Source of snapshot `request_id`s. It starts at 1, and `request_snapshot`
/// skips 0 if the counter wraps.
static SNAPSHOT_REQUEST_COUNTER: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(1);
/// Serializes `request_snapshot` calls, so only one caller at a time sends a
/// request and waits for its reply.
static SNAPSHOT_REQUEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Number of initial `ppoll` rounds in `bounded_read_exact` that use the short
/// `SNAPSHOT_FAST_POLL_INTERVAL` before switching to the slow interval.
const SNAPSHOT_FAST_POLL_ITERS: u32 = 4;
/// Poll slice (100 µs) for the first `SNAPSHOT_FAST_POLL_ITERS` rounds.
const SNAPSHOT_FAST_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_micros(100);
/// Poll slice (5 ms) for every round after the fast ones.
const SNAPSHOT_SLOW_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5);
/// Fills `buf` completely from `f`, giving up once `deadline` has passed.
///
/// Waits for readability with `ppoll` in short slices, each capped at the time
/// left before `deadline`:
/// - `SNAPSHOT_FAST_POLL_INTERVAL` for the first `SNAPSHOT_FAST_POLL_ITERS`
///   rounds;
/// - `SNAPSHOT_SLOW_POLL_INTERVAL` after that.
///
/// After each poll that reports an event, it reads whatever is available.
/// `EINTR` from `ppoll` or `read` is retried.
///
/// # Errors
/// - `TimedOut` if `deadline` passes before `buf` is full.
/// - `UnexpectedEof` if `read` returns 0 before `buf` is full.
/// - Any other error returned by `ppoll` or `read`.
fn bounded_read_exact(
    f: &mut std::fs::File,
    buf: &mut [u8],
    deadline: std::time::Instant,
) -> std::io::Result<()> {
    use std::io::Read;
    use std::os::unix::io::AsRawFd;
    let fd = f.as_raw_fd();
    let mut filled = 0usize;
    let mut iter: u32 = 0;
    while filled < buf.len() {
        let now = std::time::Instant::now();
        if now >= deadline {
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "snapshot reply deadline elapsed after reading {filled} of {} header/payload bytes",
                    buf.len()
                ),
            ));
        }
        let remaining = deadline - now;
        // Short polls first for low latency when a reply is already on its
        // way, then longer polls to avoid busy-waiting.
        let interval = if iter < SNAPSHOT_FAST_POLL_ITERS {
            SNAPSHOT_FAST_POLL_INTERVAL
        } else {
            SNAPSHOT_SLOW_POLL_INTERVAL
        };
        let slice = remaining.min(interval);
        let ts = libc::timespec {
            tv_sec: slice.as_secs() as libc::time_t,
            tv_nsec: slice.subsec_nanos() as libc::c_long,
        };
        let mut pfd = libc::pollfd {
            fd,
            events: libc::POLLIN,
            revents: 0,
        };
        // SAFETY: `pfd` and `ts` are initialized locals that outlive the call.
        // nfds is 1, matching the single `pollfd`. A null sigmask leaves the
        // thread's signal mask unchanged.
        let pr = unsafe { libc::ppoll(&mut pfd, 1, &ts, std::ptr::null()) };
        iter = iter.saturating_add(1);
        if pr < 0 {
            let err = std::io::Error::last_os_error();
            if err.kind() == std::io::ErrorKind::Interrupted {
                continue;
            }
            return Err(err);
        }
        // The slice expired with nothing ready; re-check the deadline.
        if pr == 0 {
            continue;
        }
        match f.read(&mut buf[filled..]) {
            Ok(0) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    format!(
                        "snapshot reply read returned 0 after {filled} of {} bytes",
                        buf.len()
                    ),
                ));
            }
            Ok(n) => {
                filled += n;
            }
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(())
}
/// Reads one frame (a `ShmMessage` header plus its payload) from the bulk port,
/// giving up at `deadline`.
///
/// A declared length larger than `SnapshotReplyPayload` (the largest frame
/// expected on this receive path) is rejected before any allocation. The
/// payload's CRC32 is checked against the header.
///
/// Returns `(msg_type, payload)` on success.
///
/// # Errors
/// - Timeout, EOF or I/O errors propagated from `bounded_read_exact`.
/// - `InvalidData` if the header cannot be decoded, if the declared length
///   exceeds the cap, or if the CRC does not match.
fn read_bulk_port_frame(
    f: &mut std::fs::File,
    deadline: std::time::Instant,
) -> std::io::Result<(u32, Vec<u8>)> {
    let max_len = std::mem::size_of::<SnapshotReplyPayload>();

    let mut raw_header = [0u8; std::mem::size_of::<ShmMessage>()];
    bounded_read_exact(f, &mut raw_header, deadline)?;
    let Ok(header) = ShmMessage::read_from_bytes(&raw_header) else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "ShmMessage::read_from_bytes failed (header underflow)",
        ));
    };

    let declared = header.length as usize;
    if declared > max_len {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "TLV length {declared} exceeds max payload {max_len} for port-1 RX; \
                 rejecting before allocation to avoid guest OOM"
            ),
        ));
    }

    let mut body = vec![0u8; declared];
    if !body.is_empty() {
        bounded_read_exact(f, &mut body, deadline)?;
    }

    let expected_crc = header.crc32;
    let actual_crc = crc32fast::hash(&body);
    if actual_crc == expected_crc {
        Ok((header.msg_type, body))
    } else {
        Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "TLV CRC mismatch: header crc=0x{expected_crc:08x} computed=0x{actual_crc:08x} length={declared}"
            ),
        ))
    }
}
/// Asks the host to take a snapshot, then blocks until the matching reply
/// arrives on the bulk port or `timeout` elapses.
///
/// `kind` is forwarded verbatim. `tag` is copied into the fixed-size tag field
/// and truncated byte-wise to `SNAPSHOT_TAG_MAX` bytes, so a multi-byte UTF-8
/// character may be cut.
///
/// Each request carries a fresh non-zero `request_id`. The following frames
/// are logged and skipped:
/// - frames of any other message type;
/// - malformed reply payloads;
/// - replies with a different id (e.g. a late reply to an earlier timed-out
///   request).
///
/// Calls are serialized by `SNAPSHOT_REQUEST_LOCK`. The bulk-port handle lock
/// is held while waiting, so other frame writers block until this returns.
///
/// # Returns
/// - `Ok` if the host replies with `SNAPSHOT_STATUS_OK`.
/// - `HostError { reason }` if it replies with `SNAPSHOT_STATUS_ERR`.
///   `reason` is the reply's NUL-terminated reason text, decoded as lossy UTF-8.
/// - `TransportError { reason }` in any of these cases:
///   - called outside the guest;
///   - the request could not be written;
///   - the port could not be opened or read;
///   - the deadline passed;
///   - the reply carried an unknown status.
pub fn request_snapshot(
    kind: u32,
    tag: &str,
    timeout: std::time::Duration,
) -> SnapshotRequestResult {
    if !is_guest() {
        return SnapshotRequestResult::TransportError {
            reason: "request_snapshot called from host context (virtio-console port 1 \
                     is reachable only from inside the guest)"
                .into(),
        };
    }
    // One request at a time: the reply loop assumes it is the only reader
    // matching replies by request_id.
    let _guard = SNAPSHOT_REQUEST_LOCK
        .lock()
        .unwrap_or_else(|e| e.into_inner());
    // Never hand out request_id 0; if the counter wrapped to 0, take the next value.
    let mut request_id =
        SNAPSHOT_REQUEST_COUNTER.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
    if request_id == 0 {
        request_id = SNAPSHOT_REQUEST_COUNTER.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
    }
    let tag_bytes = tag.as_bytes();
    let tag_len = tag_bytes.len().min(SNAPSHOT_TAG_MAX);
    let mut tag_buf = [0u8; SNAPSHOT_TAG_MAX];
    tag_buf[..tag_len].copy_from_slice(&tag_bytes[..tag_len]);
    let payload = SnapshotRequestPayload {
        request_id,
        kind,
        tag: tag_buf,
    };
    let bytes = payload.as_bytes();
    // If the request never reached the host, no reply can arrive. Fail now
    // rather than wait out the whole timeout and report a misleading
    // "no reply" error.
    if !write_msg(MsgType::SnapshotRequest.wire_value(), bytes) {
        return SnapshotRequestResult::TransportError {
            reason: format!(
                "failed to write snapshot request to /dev/vport0p1 \
                 (request_id={request_id}, kind={kind})"
            ),
        };
    }
    let read_slot = BULK_PORT_FD.get_or_init(|| std::sync::Mutex::new(None));
    let mut read_guard = read_slot.lock().unwrap_or_else(|e| e.into_inner());
    if read_guard.is_none() {
        match try_open_bulk_port() {
            Some(f) => *read_guard = Some(f),
            None => {
                return SnapshotRequestResult::TransportError {
                    reason: "/dev/vport0p1 not yet open \
                             (multiport handshake still in flight)"
                        .into(),
                };
            }
        }
    }
    let f = read_guard
        .as_mut()
        .expect("bulk port handle just installed");
    let started = std::time::Instant::now();
    // `Instant + Duration` panics on overflow. An unrepresentable timeout
    // (e.g. `Duration::MAX` meaning "wait indefinitely") falls back to a
    // deadline one year out instead of crashing the guest.
    let deadline = started
        .checked_add(timeout)
        .unwrap_or_else(|| started + std::time::Duration::from_secs(365 * 24 * 60 * 60));
    loop {
        let now = std::time::Instant::now();
        if now >= deadline {
            return SnapshotRequestResult::TransportError {
                reason: format!(
                    "host did not deliver matching snapshot reply within {timeout:?} \
                     (request_id={request_id}, kind={kind})"
                ),
            };
        }
        let frame = match read_bulk_port_frame(f, deadline) {
            Ok(frame) => frame,
            Err(e) if e.kind() == std::io::ErrorKind::TimedOut => {
                return SnapshotRequestResult::TransportError {
                    reason: format!(
                        "snapshot reply deadline elapsed before frame complete \
                         (request_id={request_id}, kind={kind}): {e}"
                    ),
                };
            }
            Err(e) => {
                // Drop the handle so the next user reopens the port.
                *read_guard = None;
                return SnapshotRequestResult::TransportError {
                    reason: format!("snapshot reply read failed (request_id={request_id}): {e}"),
                };
            }
        };
        let (msg_type, frame_payload) = frame;
        if msg_type != MSG_TYPE_SNAPSHOT_REPLY {
            tracing::warn!(
                msg_type,
                len = frame_payload.len(),
                request_id,
                "request_snapshot: ignoring unexpected TLV on port 1 RX (only \
                 SnapshotReply is expected on this transport in current protocol)"
            );
            continue;
        }
        if frame_payload.len() != std::mem::size_of::<SnapshotReplyPayload>() {
            tracing::warn!(
                request_id,
                got = frame_payload.len(),
                want = std::mem::size_of::<SnapshotReplyPayload>(),
                "request_snapshot: malformed reply payload size; ignoring"
            );
            continue;
        }
        let reply = match SnapshotReplyPayload::read_from_bytes(&frame_payload) {
            Ok(r) => r,
            Err(_) => {
                tracing::warn!(
                    request_id,
                    "request_snapshot: SnapshotReplyPayload::read_from_bytes failed; ignoring"
                );
                continue;
            }
        };
        if reply.request_id != request_id {
            tracing::warn!(
                expected = request_id,
                got = reply.request_id,
                "request_snapshot: stale reply id (likely a leftover from a prior \
                 request that timed out on the guest side); ignoring"
            );
            continue;
        }
        return match reply.status {
            SNAPSHOT_STATUS_OK => SnapshotRequestResult::Ok,
            SNAPSHOT_STATUS_ERR => {
                // The reason field is NUL-terminated, or fills the whole buffer.
                let len = reply
                    .reason
                    .iter()
                    .position(|&b| b == 0)
                    .unwrap_or(SNAPSHOT_REASON_MAX);
                let reason = String::from_utf8_lossy(&reply.reason[..len]).to_string();
                SnapshotRequestResult::HostError { reason }
            }
            other => SnapshotRequestResult::TransportError {
                reason: format!(
                    "host reply with unknown status {other} \
                     (expected OK={SNAPSHOT_STATUS_OK} or ERR={SNAPSHOT_STATUS_ERR})"
                ),
            },
        };
    }
}
// Unit tests. The send_* cases pin the host-context no-op contract via the
// thread-local `is_guest` override. The read_bulk_port_frame cases drive the
// frame decoder through a pipe(2) pair.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn send_exit_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_exit(0);
        send_exit(-1);
    }
    #[test]
    fn send_test_result_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_test_result(&crate::assert::AssertResult::pass());
    }
    #[test]
    fn send_payload_metrics_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        let pm = crate::test_support::PayloadMetrics {
            payload_index: 0,
            metrics: vec![],
            exit_code: 0,
        };
        send_payload_metrics(&pm);
    }
    #[test]
    fn send_profraw_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_profraw(b"\x01\x02\x03");
    }
    #[test]
    fn send_stimulus_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_stimulus(&[0u8; 24]);
    }
    #[test]
    fn send_raw_payload_output_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        let raw = crate::test_support::RawPayloadOutput {
            payload_index: 0,
            stdout: String::new(),
            stderr: String::new(),
            hint: None,
            metric_hints: vec![],
            metric_bounds: None,
        };
        send_raw_payload_output(&raw);
    }
    #[test]
    fn send_sched_exit_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_sched_exit(0);
        send_sched_exit(-1);
    }
    #[test]
    fn send_scenario_start_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_scenario_start();
    }
    #[test]
    fn send_scenario_end_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_scenario_end(0);
        send_scenario_end(u64::MAX);
    }
    #[test]
    fn send_sys_rdy_from_host_context_returns_false() {
        let _g = IsGuestOverrideGuard::new(false);
        assert!(
            !send_sys_rdy(),
            "host-context call must return false so the guest's \
             retry loop can distinguish 'wrote' from 'noop'"
        );
    }
    #[test]
    fn send_stdout_chunk_from_host_context_returns_false() {
        let _g = IsGuestOverrideGuard::new(false);
        assert!(!send_stdout_chunk(b"hello"));
    }
    #[test]
    fn send_stderr_chunk_from_host_context_returns_false() {
        let _g = IsGuestOverrideGuard::new(false);
        assert!(!send_stderr_chunk(b"oops"));
    }
    #[test]
    fn send_sched_log_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_sched_log(b"---SCHED_OUTPUT_START---\n");
    }
    #[test]
    fn send_lifecycle_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_lifecycle(LifecyclePhase::InitStarted, "");
        send_lifecycle(LifecyclePhase::PayloadStarting, "");
        send_lifecycle(LifecyclePhase::SchedulerDied, "");
        send_lifecycle(LifecyclePhase::SchedulerNotAttached, "verifier rejected");
    }
    #[test]
    fn send_exec_exit_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_exec_exit(0);
        send_exec_exit(-1);
    }
    #[test]
    fn send_dmesg_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_dmesg(b"[ 0.000000] Linux version 6.16.0\n");
    }
    #[test]
    fn send_probe_output_from_host_context_is_noop() {
        let _g = IsGuestOverrideGuard::new(false);
        send_probe_output(b"{}\n");
    }
    #[test]
    fn request_snapshot_from_host_context_returns_transport_error() {
        let _g = IsGuestOverrideGuard::new(false);
        let r = request_snapshot(0, "tag", std::time::Duration::from_millis(0));
        match r {
            SnapshotRequestResult::TransportError { .. } => {}
            other => panic!("expected TransportError from host context, got {other:?}"),
        }
    }
    #[test]
    fn read_bulk_port_frame_rejects_oversized_length_before_alloc() {
        use std::os::unix::io::FromRawFd;
        let mut fds = [0i32; 2];
        let r = unsafe { libc::pipe(fds.as_mut_ptr()) };
        assert_eq!(r, 0, "pipe(2) failed: {}", std::io::Error::last_os_error());
        let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
        let mut write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
        let header = ShmMessage {
            msg_type: MSG_TYPE_SNAPSHOT_REPLY,
            length: u32::MAX,
            crc32: 0,
            _pad: 0,
        };
        use std::io::Write;
        write_end
            .write_all(header.as_bytes())
            .expect("write forged header");
        drop(write_end);
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        let err = read_bulk_port_frame(&mut read_end, deadline)
            .expect_err("oversized length must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds max payload"),
            "error must explain the cap, got: {msg}"
        );
    }
    #[test]
    fn read_bulk_port_frame_accepts_exact_max_payload() {
        use std::os::unix::io::FromRawFd;
        let mut fds = [0i32; 2];
        let r = unsafe { libc::pipe(fds.as_mut_ptr()) };
        assert_eq!(r, 0, "pipe(2) failed: {}", std::io::Error::last_os_error());
        let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
        let mut write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
        let payload = vec![0u8; std::mem::size_of::<SnapshotReplyPayload>()];
        let header = ShmMessage {
            msg_type: MSG_TYPE_SNAPSHOT_REPLY,
            length: payload.len() as u32,
            crc32: crc32fast::hash(&payload),
            _pad: 0,
        };
        use std::io::Write;
        write_end.write_all(header.as_bytes()).expect("header");
        write_end.write_all(&payload).expect("payload");
        drop(write_end);
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        let (msg_type, body) =
            read_bulk_port_frame(&mut read_end, deadline).expect("exact-size payload must succeed");
        assert_eq!(msg_type, MSG_TYPE_SNAPSHOT_REPLY);
        assert_eq!(body.len(), std::mem::size_of::<SnapshotReplyPayload>());
    }
    // A payload whose CRC disagrees with the header must be rejected, not
    // handed to the reply decoder.
    #[test]
    fn read_bulk_port_frame_rejects_crc_mismatch() {
        use std::os::unix::io::FromRawFd;
        let mut fds = [0i32; 2];
        let r = unsafe { libc::pipe(fds.as_mut_ptr()) };
        assert_eq!(r, 0, "pipe(2) failed: {}", std::io::Error::last_os_error());
        let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
        let mut write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
        let payload = vec![0xA5u8; 4];
        let header = ShmMessage {
            msg_type: MSG_TYPE_SNAPSHOT_REPLY,
            length: payload.len() as u32,
            crc32: crc32fast::hash(&payload) ^ 1,
            _pad: 0,
        };
        use std::io::Write;
        write_end.write_all(header.as_bytes()).expect("header");
        write_end.write_all(&payload).expect("payload");
        drop(write_end);
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        let err = read_bulk_port_frame(&mut read_end, deadline)
            .expect_err("corrupted CRC must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(msg.contains("CRC mismatch"), "error must name the CRC check, got: {msg}");
    }
    // The peer closing mid-header must surface as EOF rather than a hang or
    // a bogus frame.
    #[test]
    fn read_bulk_port_frame_reports_eof_on_truncated_header() {
        use std::os::unix::io::FromRawFd;
        let mut fds = [0i32; 2];
        let r = unsafe { libc::pipe(fds.as_mut_ptr()) };
        assert_eq!(r, 0, "pipe(2) failed: {}", std::io::Error::last_os_error());
        let mut read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
        let mut write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
        use std::io::Write;
        write_end.write_all(&[0u8; 3]).expect("partial header");
        drop(write_end);
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        let err = read_bulk_port_frame(&mut read_end, deadline)
            .expect_err("truncated header must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }
    #[test]
    fn is_guest_override_round_trips_through_thread_local() {
        {
            let _g = IsGuestOverrideGuard::new(false);
            assert!(!is_guest());
        }
        {
            let _g = IsGuestOverrideGuard::new(true);
            assert!(is_guest());
        }
    }
    #[test]
    fn is_guest_override_guards_nest_correctly() {
        let _outer = IsGuestOverrideGuard::new(true);
        assert!(is_guest());
        {
            let _inner = IsGuestOverrideGuard::new(false);
            assert!(!is_guest());
        }
        assert!(is_guest());
    }
}