use super::*;
use crate::vmm::wire::{
FRAME_HEADER_SIZE, MSG_TYPE_SNAPSHOT_REPLY, SNAPSHOT_KIND_CAPTURE, SNAPSHOT_KIND_NONE,
SNAPSHOT_KIND_WATCH, SNAPSHOT_REASON_MAX, SNAPSHOT_STATUS_ERR, SNAPSHOT_STATUS_OK,
SNAPSHOT_TAG_MAX, ShmMessage, SnapshotReplyPayload, SnapshotRequestPayload,
};
use zerocopy::{FromBytes, IntoBytes};
/// Serializes a `SnapshotRequestPayload` into an owned byte vector for the
/// decode tests.
///
/// The tag's UTF-8 bytes are copied into a zero-filled buffer of
/// `SNAPSHOT_TAG_MAX` bytes; any bytes beyond that capacity are dropped and
/// unused trailing bytes stay zero.
fn make_request_bytes(request_id: u32, kind: u32, tag: &str) -> Vec<u8> {
    let mut padded_tag = [0u8; SNAPSHOT_TAG_MAX];
    // `zip` stops at the shorter side, so an over-long tag is cut at the
    // buffer capacity and a short tag leaves the remainder zeroed.
    for (slot, byte) in padded_tag.iter_mut().zip(tag.bytes()) {
        *slot = byte;
    }
    let payload = SnapshotRequestPayload {
        request_id,
        kind,
        tag: padded_tag,
    };
    payload.as_bytes().to_vec()
}
/// A capture request built by `make_request_bytes` must decode with the same
/// request id, kind and tag it was encoded with.
#[test]
fn decode_capture_request_round_trip() {
    let encoded = make_request_bytes(7, SNAPSHOT_KIND_CAPTURE, "snap_1");
    let decoded = decode_snapshot_request(&encoded).expect("valid request decodes");

    assert_eq!(decoded.request_id, 7);
    assert_eq!(decoded.kind, SNAPSHOT_KIND_CAPTURE);
    assert_eq!(decoded.tag, "snap_1");
}
/// A watch request built by `make_request_bytes` must decode with the same
/// request id, kind and tag it was encoded with.
#[test]
fn decode_watch_request_round_trip() {
    let bytes = make_request_bytes(99, SNAPSHOT_KIND_WATCH, "scx_root");
    let req = decode_snapshot_request(&bytes).expect("valid request decodes");
    // The request id is part of the round trip as well; checking it here
    // keeps this test in line with the capture round-trip test.
    assert_eq!(req.request_id, 99);
    assert_eq!(req.kind, SNAPSHOT_KIND_WATCH);
    assert_eq!(req.tag, "scx_root");
}
/// A payload that is one byte shorter than a well-formed request must be
/// rejected by the decoder.
#[test]
fn decode_rejects_undersized_payload() {
    let mut encoded = make_request_bytes(1, SNAPSHOT_KIND_CAPTURE, "x");
    // Drop the final byte so the buffer no longer has the full payload size.
    encoded.truncate(encoded.len().saturating_sub(1));

    let outcome = decode_snapshot_request(&encoded);
    assert!(outcome.is_none());
}
/// A payload carrying one extra trailing byte beyond a well-formed request
/// must be rejected by the decoder.
#[test]
fn decode_rejects_oversized_payload() {
    let mut encoded = make_request_bytes(1, SNAPSHOT_KIND_CAPTURE, "x");
    // Append a single stray byte past the end of the payload.
    encoded.extend_from_slice(&[0xAA]);

    let outcome = decode_snapshot_request(&encoded);
    assert!(outcome.is_none());
}
/// A request whose `request_id` is zero must be rejected by the decoder.
#[test]
fn decode_rejects_zero_request_id() {
    let encoded = make_request_bytes(0, SNAPSHOT_KIND_CAPTURE, "x");
    let outcome = decode_snapshot_request(&encoded);
    assert!(outcome.is_none());
}
/// A request whose kind is `SNAPSHOT_KIND_NONE` must be rejected by the
/// decoder.
#[test]
fn decode_rejects_kind_none() {
    let encoded = make_request_bytes(1, SNAPSHOT_KIND_NONE, "x");
    let outcome = decode_snapshot_request(&encoded);
    assert!(outcome.is_none());
}
/// A request with a kind value that matches none of the `SNAPSHOT_KIND_*`
/// constants used in this file must still decode, with every field passed
/// through unchanged.
#[test]
fn decode_accepts_unknown_kind_for_dispatch_handling() {
    let bytes = make_request_bytes(42, 0xDEAD_BEEF, "tag");
    let req = decode_snapshot_request(&bytes).expect("decode succeeds");
    // Without this check a decoder that mangles the id for unrecognized
    // kinds would go unnoticed.
    assert_eq!(req.request_id, 42);
    assert_eq!(req.kind, 0xDEAD_BEEF);
    assert_eq!(req.tag, "tag");
}
/// A tag that fills the whole `SNAPSHOT_TAG_MAX`-byte buffer (leaving no zero
/// byte after it) must decode to a tag of exactly that length.
#[test]
fn decode_full_buffer_tag_uses_full_length() {
    let full_tag: String = std::iter::repeat('a').take(SNAPSHOT_TAG_MAX).collect();
    let encoded = make_request_bytes(1, SNAPSHOT_KIND_CAPTURE, &full_tag);
    let decoded = decode_snapshot_request(&encoded).expect("decode succeeds");

    assert_eq!(decoded.tag.len(), SNAPSHOT_TAG_MAX);
    // Every byte being `a` is equivalent to every char being 'a'.
    assert!(decoded.tag.bytes().all(|b| b == b'a'));
}
/// Checks the framing of a reply: total size is header plus payload, the
/// header carries the reply message type and the payload length, and the
/// header CRC equals `crc32fast::hash` of the payload bytes.
#[test]
fn frame_reply_size_and_crc() {
    let payload_size = std::mem::size_of::<SnapshotReplyPayload>();
    let frame = frame_snapshot_reply(123, SNAPSHOT_STATUS_OK, "");
    assert_eq!(frame.len(), FRAME_HEADER_SIZE + payload_size);

    let header =
        ShmMessage::read_from_bytes(&frame[..FRAME_HEADER_SIZE]).expect("header decodes");
    assert_eq!(header.msg_type, MSG_TYPE_SNAPSHOT_REPLY);
    assert_eq!(header.length as usize, payload_size);

    // The CRC is computed over everything that follows the header.
    let expected_crc = crc32fast::hash(&frame[FRAME_HEADER_SIZE..]);
    assert_eq!(header.crc32, expected_crc);
}
/// The payload that follows the frame header must decode back to the request
/// id, status and reason text passed to `frame_snapshot_reply`.
#[test]
fn frame_reply_payload_round_trip() {
    let frame = frame_snapshot_reply(0xCAFE_BABE, SNAPSHOT_STATUS_ERR, "rendezvous timeout");
    let reply = SnapshotReplyPayload::read_from_bytes(&frame[FRAME_HEADER_SIZE..])
        .expect("payload decodes");

    assert_eq!(reply.request_id, 0xCAFE_BABE);
    assert_eq!(reply.status, SNAPSHOT_STATUS_ERR);

    // The reason text ends at the first zero byte, or spans the whole field
    // when no zero byte is present.
    let text_end = match reply.reason.iter().position(|&byte| byte == 0) {
        Some(nul_at) => nul_at,
        None => SNAPSHOT_REASON_MAX,
    };
    assert_eq!(&reply.reason[..text_end], b"rendezvous timeout");
}
/// A reason longer than `SNAPSHOT_REASON_MAX` must still frame and decode,
/// with the reason field completely filled by the leading bytes of the input.
#[test]
fn frame_reply_truncates_long_reason() {
    let oversized_reason = "x".repeat(SNAPSHOT_REASON_MAX + 16);
    let frame = frame_snapshot_reply(1, SNAPSHOT_STATUS_ERR, &oversized_reason);
    let reply = SnapshotReplyPayload::read_from_bytes(&frame[FRAME_HEADER_SIZE..])
        .expect("payload decodes");

    assert_eq!(reply.reason.len(), SNAPSHOT_REASON_MAX);
    // No byte in the field may differ from the input's fill character.
    assert!(!reply.reason.iter().any(|&byte| byte != b'x'));
}
/// An empty reason must leave the reason field filled entirely with zero
/// bytes.
#[test]
fn frame_reply_empty_reason_zero_pads() {
    let frame = frame_snapshot_reply(1, SNAPSHOT_STATUS_OK, "");
    let reply = SnapshotReplyPayload::read_from_bytes(&frame[FRAME_HEADER_SIZE..])
        .expect("payload decodes");

    assert!(!reply.reason.iter().any(|&byte| byte != 0));
}
/// Frames a long "Fix C" KASLR diagnostic as an error reply and checks that
/// the substrings named in the assertion messages below are still present in
/// the reason text recovered from the payload.
#[test]
fn frame_reply_preserves_fix_c_diagnostic_substrings() {
    let fix_c = "symbol 'jiffies_64' link_kva 0xffffffff812d2000 is in \
                 x86_64 kernel high-half but kern_virt_kaslr Arc has not \
                 published a non-zero slide (coord_kaslr_offset() == 0); \
                 refusing to arm DR at link-time KVA — would never match \
                 guest writes under KASLR-on. If nokaslr semantics intended, \
                 set `#[ktstr_test(kaslr = false)]`.";
    let frame = frame_snapshot_reply(1, SNAPSHOT_STATUS_ERR, fix_c);
    let reply = SnapshotReplyPayload::read_from_bytes(&frame[FRAME_HEADER_SIZE..])
        .expect("payload decodes");

    // The reason text ends at the first zero byte, or spans the whole field
    // when no zero byte is present.
    let text_end = match reply.reason.iter().position(|&byte| byte == 0) {
        Some(nul_at) => nul_at,
        None => SNAPSHOT_REASON_MAX,
    };
    // The binding is named `s` because the assertion messages capture it by
    // name in their format strings.
    let s = std::str::from_utf8(&reply.reason[..text_end])
        .expect("Fix C diagnostic is valid UTF-8");

    assert!(
        s.contains("kaslr_offset"),
        "diagnostic lost 'kaslr_offset' substring; \
         tests/kaslr_axis_e2e.rs greps for it (recorded reason: {s:?})"
    );
    assert!(
        s.contains("kern_virt_kaslr"),
        "diagnostic lost 'kern_virt_kaslr' substring; \
         tests/kaslr_axis_e2e.rs greps for it (recorded reason: {s:?})"
    );
    assert!(
        s.contains("#[ktstr_test(kaslr = false)]"),
        "diagnostic lost the actionable remediation tip; \
         test author would not know how to unblock (recorded reason: {s:?})"
    );
}