use std::sync::Arc;
use zerocopy::FromBytes;
use super::bulk::MAX_BULK_FRAME_PAYLOAD;
use super::pi_mutex::PiMutex;
use super::virtio_console::{
SIGNAL_BPF_WRITE_DONE, SIGNAL_VC_DUMP, SIGNAL_VC_SHUTDOWN, VirtioConsole,
};
use super::wire::{FRAME_HEADER_SIZE, ShmEntry, ShmMessage};
/// Outcome of decoding a bulk-port byte stream with [`parse_tlv_stream`].
#[derive(Clone, Debug, Default)]
pub struct BulkDrainResult {
    /// Decoded frames, in the order they appeared in the stream. Frames whose
    /// payload CRC did not match the header are still present, with
    /// `crc_ok == false`.
    pub entries: Vec<ShmEntry>,
}
/// Pull the pending bulk-port bytes out of `dev` (via `VirtioConsole::drain_bulk`)
/// and decode them as a TLV frame stream with [`parse_tlv_stream`].
// NOTE(review): the only caller visible in this file is the unit test
// `drain_bulk_empty_device_yields_no_entries`, which is presumably why
// `dead_code` is silenced for non-test builds — confirm and document.
#[allow(dead_code)]
pub fn drain_bulk(dev: &mut VirtioConsole) -> BulkDrainResult {
    parse_tlv_stream(&dev.drain_bulk())
}
/// Decode a bulk-port byte stream into a sequence of TLV frames.
///
/// Each frame is a fixed-size [`ShmMessage`] header (`FRAME_HEADER_SIZE`
/// bytes) followed by `length` payload bytes. Frames are returned in stream
/// order.
///
/// Walk policy:
/// - A frame whose payload CRC does not match the header's `crc32` is still
///   returned, with `crc_ok = false`. The walk continues past it using the
///   header-announced length.
/// - A header announcing more than `MAX_BULK_FRAME_PAYLOAD` payload bytes
///   stops the walk. Nothing after it is returned, even if it looks valid.
/// - Trailing bytes that do not form a complete frame are dropped. This
///   covers a partial header, or a header whose payload runs past the end of
///   `buf`.
/// - A non-zero `_pad` field is logged but does not reject the frame.
///
/// Every abnormal condition is reported with `tracing::warn!`; only a clean
/// end of buffer is silent.
pub fn parse_tlv_stream(buf: &[u8]) -> BulkDrainResult {
    let mut entries: Vec<ShmEntry> = Vec::new();
    let mut pos = 0usize;
    loop {
        // `pos` only ever advances to a `payload_end` that was checked against
        // `buf.len()` below, so this subtraction cannot underflow.
        let remaining = buf.len() - pos;
        if remaining == 0 {
            break;
        }
        if remaining < FRAME_HEADER_SIZE {
            // Previously dropped silently; surfaced because these bytes are lost.
            tracing::warn!(
                remaining,
                header_size = FRAME_HEADER_SIZE,
                "parse_tlv_stream: dropping trailing bytes shorter than a frame header"
            );
            break;
        }
        let hdr_end = pos + FRAME_HEADER_SIZE;
        // zerocopy's `read_from_bytes` fails only when the slice length differs
        // from `size_of::<ShmMessage>()`, i.e. if FRAME_HEADER_SIZE disagrees
        // with the struct layout.
        let Ok(msg) = ShmMessage::read_from_bytes(&buf[pos..hdr_end]) else {
            break;
        };
        if msg._pad != 0 {
            tracing::warn!(
                msg_type = msg.msg_type,
                length = msg.length,
                pad = msg._pad,
                "parse_tlv_stream: non-zero _pad in frame header; possible hostile guest covert channel"
            );
        }
        // Reject oversized announcements before using the length for slicing:
        // once the header is untrustworthy, frame boundaries after it are too.
        if msg.length > MAX_BULK_FRAME_PAYLOAD {
            tracing::warn!(
                msg_type = msg.msg_type,
                length = msg.length,
                cap = MAX_BULK_FRAME_PAYLOAD,
                "parse_tlv_stream: dropping oversized frame; stopping walk"
            );
            break;
        }
        // `length` is a u32; widening to usize is lossless on 32/64-bit targets.
        let payload_len = msg.length as usize;
        let available = remaining - FRAME_HEADER_SIZE;
        if payload_len > available {
            tracing::warn!(
                msg_type = msg.msg_type,
                length = msg.length,
                available,
                "parse_tlv_stream: dropping truncated trailing frame"
            );
            break;
        }
        let payload_end = hdr_end + payload_len;
        let payload = buf[hdr_end..payload_end].to_vec();
        let computed_crc = crc32fast::hash(&payload);
        let crc_ok = computed_crc == msg.crc32;
        if !crc_ok {
            tracing::warn!(
                msg_type = msg.msg_type,
                length = msg.length,
                expected_crc = msg.crc32,
                computed_crc,
                "parse_tlv_stream: per-frame CRC mismatch; surfacing crc_ok=false"
            );
        }
        entries.push(ShmEntry {
            msg_type: msg.msg_type,
            payload,
            crc_ok,
        });
        pos = payload_end;
    }
    BulkDrainResult { entries }
}
/// Queue the single-element input `[SIGNAL_VC_DUMP]` on the shared console.
/// The lock is held only for the duration of the `queue_input` call.
pub fn request_dump(virtio_con: &Arc<PiMutex<VirtioConsole>>) {
    let signal = [SIGNAL_VC_DUMP];
    virtio_con.lock().queue_input(&signal);
}
/// Queue the single-element input `[SIGNAL_VC_SHUTDOWN]` on the shared console.
/// The lock is held only for the duration of the `queue_input` call.
pub fn request_shutdown(virtio_con: &Arc<PiMutex<VirtioConsole>>) {
    let signal = [SIGNAL_VC_SHUTDOWN];
    virtio_con.lock().queue_input(&signal);
}
/// Queue the single-element input `[SIGNAL_BPF_WRITE_DONE]` on the shared
/// console. The lock is held only for the duration of the `queue_input` call.
pub fn request_bpf_map_write_done(virtio_con: &Arc<PiMutex<VirtioConsole>>) {
    let signal = [SIGNAL_BPF_WRITE_DONE];
    virtio_con.lock().queue_input(&signal);
}
#[cfg(test)]
mod tests {
    use super::super::wire::{MSG_TYPE_EXIT, MSG_TYPE_STIMULUS};
    use super::*;
    use zerocopy::IntoBytes;

    /// Encode a well-formed frame: header with the correct length and CRC,
    /// followed by `payload`.
    fn frame_bytes(msg_type: u32, payload: &[u8]) -> Vec<u8> {
        let f = ShmMessage {
            msg_type,
            length: payload.len() as u32,
            crc32: crc32fast::hash(payload),
            _pad: 0,
        };
        let mut v = Vec::with_capacity(FRAME_HEADER_SIZE + payload.len());
        v.extend_from_slice(f.as_bytes());
        v.extend_from_slice(payload);
        v
    }

    /// Build a bare header announcing `length` payload bytes, with a zero CRC
    /// and zero padding. Used to construct deliberately malformed streams.
    fn header_only(msg_type: u32, length: u32) -> ShmMessage {
        ShmMessage {
            msg_type,
            length,
            crc32: 0,
            _pad: 0,
        }
    }

    #[test]
    fn parse_empty_buffer_yields_no_entries() {
        let r = parse_tlv_stream(&[]);
        assert!(r.entries.is_empty());
    }

    #[test]
    fn parse_single_frame_one_entry() {
        let bytes = frame_bytes(MSG_TYPE_EXIT, &42i32.to_le_bytes());
        let r = parse_tlv_stream(&bytes);
        assert_eq!(r.entries.len(), 1);
        assert_eq!(r.entries[0].msg_type, MSG_TYPE_EXIT);
        assert!(r.entries[0].crc_ok);
        assert_eq!(r.entries[0].payload, 42i32.to_le_bytes());
    }

    #[test]
    fn parse_multiple_frames_preserve_order() {
        let mut buf = Vec::new();
        buf.extend(frame_bytes(MSG_TYPE_STIMULUS, b"first"));
        buf.extend(frame_bytes(MSG_TYPE_EXIT, b"second"));
        buf.extend(frame_bytes(MSG_TYPE_STIMULUS, b"third"));
        let r = parse_tlv_stream(&buf);
        assert_eq!(r.entries.len(), 3);
        assert_eq!(r.entries[0].payload, b"first");
        assert_eq!(r.entries[1].payload, b"second");
        assert_eq!(r.entries[2].payload, b"third");
    }

    #[test]
    fn parse_drops_trailing_partial_frame() {
        let bytes = frame_bytes(MSG_TYPE_EXIT, b"complete");
        let truncated = &bytes[..bytes.len() - 2];
        let r = parse_tlv_stream(truncated);
        assert!(r.entries.is_empty());
    }

    #[test]
    fn parse_keeps_complete_frames_before_truncated_payload() {
        let mut buf = frame_bytes(MSG_TYPE_STIMULUS, b"kept");
        let partial = frame_bytes(MSG_TYPE_EXIT, b"truncated");
        buf.extend_from_slice(&partial[..partial.len() - 1]);
        let r = parse_tlv_stream(&buf);
        assert_eq!(r.entries.len(), 1);
        assert_eq!(r.entries[0].payload, b"kept");
        assert!(r.entries[0].crc_ok);
    }

    #[test]
    fn parse_drops_trailing_bytes_shorter_than_header() {
        let mut buf = frame_bytes(MSG_TYPE_EXIT, b"whole");
        buf.resize(buf.len() + FRAME_HEADER_SIZE - 1, 0xEE);
        let r = parse_tlv_stream(&buf);
        assert_eq!(r.entries.len(), 1);
        assert_eq!(r.entries[0].payload, b"whole");
        assert!(r.entries[0].crc_ok);
    }

    #[test]
    fn parse_crc_mismatch_marks_entry_continues_walk() {
        let mut bad = frame_bytes(MSG_TYPE_EXIT, b"payload");
        bad[FRAME_HEADER_SIZE] ^= 0xFF;
        let mut good = frame_bytes(MSG_TYPE_STIMULUS, b"valid");
        let mut combined = Vec::new();
        combined.append(&mut bad);
        combined.append(&mut good);
        let r = parse_tlv_stream(&combined);
        assert_eq!(r.entries.len(), 2);
        assert!(!r.entries[0].crc_ok);
        assert!(r.entries[1].crc_ok);
    }

    #[test]
    fn parse_nonzero_pad_still_yields_entry() {
        let payload = b"padded";
        let f = ShmMessage {
            msg_type: MSG_TYPE_EXIT,
            length: payload.len() as u32,
            crc32: crc32fast::hash(payload),
            _pad: 1,
        };
        let mut bytes = f.as_bytes().to_vec();
        bytes.extend_from_slice(payload);
        let r = parse_tlv_stream(&bytes);
        assert_eq!(
            r.entries.len(),
            1,
            "non-zero _pad is logged, not rejected"
        );
        assert!(r.entries[0].crc_ok);
        assert_eq!(r.entries[0].payload, payload);
    }

    #[test]
    fn parse_zero_length_payload() {
        let bytes = frame_bytes(MSG_TYPE_EXIT, b"");
        assert_eq!(bytes.len(), FRAME_HEADER_SIZE);
        let r = parse_tlv_stream(&bytes);
        assert_eq!(r.entries.len(), 1);
        assert!(r.entries[0].payload.is_empty());
        assert!(r.entries[0].crc_ok);
    }

    #[test]
    fn drain_bulk_empty_device_yields_no_entries() {
        let mut dev = VirtioConsole::new();
        let r = drain_bulk(&mut dev);
        assert!(r.entries.is_empty());
    }

    #[test]
    fn parse_rejects_oversized_announced_length() {
        let bad = header_only(MSG_TYPE_STIMULUS, u32::MAX);
        let r = parse_tlv_stream(bad.as_bytes());
        assert!(
            r.entries.is_empty(),
            "header announcing u32::MAX must be rejected without producing entries"
        );
        let just_over = header_only(MSG_TYPE_STIMULUS, MAX_BULK_FRAME_PAYLOAD + 1);
        let mut buf = Vec::with_capacity(FRAME_HEADER_SIZE + just_over.length as usize);
        buf.extend_from_slice(just_over.as_bytes());
        buf.resize(FRAME_HEADER_SIZE + just_over.length as usize, 0xAA);
        let r2 = parse_tlv_stream(&buf);
        assert!(
            r2.entries.is_empty(),
            "header announcing cap + 1 must be rejected by the per-frame cap check"
        );
    }

    #[test]
    fn parse_accepts_at_cap_payload() {
        let max_payload = vec![0x55u8; MAX_BULK_FRAME_PAYLOAD as usize];
        let bytes = frame_bytes(MSG_TYPE_STIMULUS, &max_payload);
        let r = parse_tlv_stream(&bytes);
        assert_eq!(
            r.entries.len(),
            1,
            "frame with length == cap must be accepted"
        );
        assert_eq!(r.entries[0].payload.len(), MAX_BULK_FRAME_PAYLOAD as usize);
        assert!(r.entries[0].crc_ok);
    }

    #[test]
    fn parse_returns_valid_frame_then_stops_at_oversized() {
        let mut combined = Vec::new();
        combined.extend_from_slice(&frame_bytes(MSG_TYPE_EXIT, b"valid"));
        let bad = header_only(MSG_TYPE_STIMULUS, u32::MAX);
        combined.extend_from_slice(bad.as_bytes());
        combined.extend_from_slice(b"residue");
        let r = parse_tlv_stream(&combined);
        assert_eq!(
            r.entries.len(),
            1,
            "valid frame must be returned even though the next header is bogus"
        );
        assert_eq!(r.entries[0].payload, b"valid");
        assert!(r.entries[0].crc_ok);
    }

    #[test]
    fn parse_stops_at_oversized_does_not_return_subsequent_valid() {
        let bad = header_only(MSG_TYPE_STIMULUS, u32::MAX);
        let mut combined = Vec::new();
        combined.extend_from_slice(bad.as_bytes());
        combined.extend_from_slice(&frame_bytes(MSG_TYPE_EXIT, b"valid"));
        let r = parse_tlv_stream(&combined);
        assert!(
            r.entries.is_empty(),
            "no entries: parser must stop at the oversized header and not resume on the trailing valid frame"
        );
    }

    #[test]
    fn parse_recognises_all_new_msg_type_variants() {
        use super::super::wire::{
            MSG_TYPE_DMESG, MSG_TYPE_EXEC_EXIT, MSG_TYPE_LIFECYCLE, MSG_TYPE_PROBE_OUTPUT,
            MSG_TYPE_SCHED_LOG, MSG_TYPE_STDERR, MSG_TYPE_STDOUT, MsgType,
        };
        let cases: &[(u32, MsgType, &[u8])] = &[
            (MSG_TYPE_STDOUT, MsgType::Stdout, b"hello\n"),
            (MSG_TYPE_STDERR, MsgType::Stderr, b"error\n"),
            (MSG_TYPE_SCHED_LOG, MsgType::SchedLog, b"---SCHED---\n"),
            (MSG_TYPE_LIFECYCLE, MsgType::Lifecycle, &[1u8]),
            (MSG_TYPE_EXEC_EXIT, MsgType::ExecExit, &0i32.to_le_bytes()),
            (MSG_TYPE_DMESG, MsgType::Dmesg, b"[ 0.000000] Linux\n"),
            (MSG_TYPE_PROBE_OUTPUT, MsgType::ProbeOutput, b"{\"k\":1}\n"),
        ];
        for (raw, typed, payload) in cases {
            let bytes = frame_bytes(*raw, payload);
            let r = parse_tlv_stream(&bytes);
            assert_eq!(
                r.entries.len(),
                1,
                "single-frame parse failed for {typed:?}",
            );
            assert!(r.entries[0].crc_ok, "CRC must round-trip for {typed:?}");
            assert_eq!(
                r.entries[0].payload, *payload,
                "payload byte mismatch for {typed:?}",
            );
            assert_eq!(
                MsgType::from_wire(r.entries[0].msg_type),
                Some(*typed),
                "from_wire decode mismatch for {typed:?}",
            );
        }
    }

    #[test]
    fn parsed_entries_match_is_coordinator_internal_classifier() {
        use super::super::wire::{
            MSG_TYPE_SNAPSHOT_REQUEST, MSG_TYPE_SYS_RDY, MSG_TYPE_TEST_RESULT, MsgType,
        };
        let internal_raw = frame_bytes(MSG_TYPE_SNAPSHOT_REQUEST, &[0u8; 72]);
        let r = parse_tlv_stream(&internal_raw);
        assert_eq!(r.entries.len(), 1);
        let typed = MsgType::from_wire(r.entries[0].msg_type).unwrap();
        assert!(typed.is_coordinator_internal());
        let internal_sys_rdy = frame_bytes(MSG_TYPE_SYS_RDY, b"");
        let r = parse_tlv_stream(&internal_sys_rdy);
        assert_eq!(r.entries.len(), 1);
        let typed = MsgType::from_wire(r.entries[0].msg_type).unwrap();
        assert!(typed.is_coordinator_internal());
        let verdict = frame_bytes(MSG_TYPE_TEST_RESULT, b"\x00");
        let r = parse_tlv_stream(&verdict);
        assert_eq!(r.entries.len(), 1);
        let typed = MsgType::from_wire(r.entries[0].msg_type).unwrap();
        assert!(!typed.is_coordinator_internal());
    }
}