use super::device::*;
use crate::vmm::net_config::NetConfig;
use proptest::prelude::*;
use std::sync::atomic::Ordering;
use virtio_bindings::virtio_config::VIRTIO_F_VERSION_1;
use virtio_bindings::virtio_mmio::{
VIRTIO_MMIO_DRIVER_FEATURES, VIRTIO_MMIO_DRIVER_FEATURES_SEL, VIRTIO_MMIO_QUEUE_AVAIL_LOW,
VIRTIO_MMIO_QUEUE_DESC_LOW, VIRTIO_MMIO_QUEUE_NOTIFY, VIRTIO_MMIO_QUEUE_NUM,
VIRTIO_MMIO_QUEUE_READY, VIRTIO_MMIO_QUEUE_SEL, VIRTIO_MMIO_QUEUE_USED_LOW, VIRTIO_MMIO_STATUS,
};
use virtio_bindings::virtio_net::VIRTIO_NET_F_MAC;
use vm_memory::{Bytes, GuestAddress, GuestMemoryMmap};
// Guest-physical memory layout shared by every proptest fixture: the TX and
// RX vrings (descriptor table, available ring, used ring) and their data
// buffers live at fixed, non-overlapping offsets inside one 1 MiB region.
const GUEST_MEM_SIZE: usize = 0x10_0000; const TX_DESC_BASE: u64 = 0x1000;
const TX_AVAIL_BASE: u64 = 0x2000;
const TX_USED_BASE: u64 = 0x3000;
// TX frame bytes are planted at 0x4000; the RX ring structures start at 0x6000.
const TX_FRAME_BUF_BASE: u64 = 0x4000; const RX_DESC_BASE: u64 = 0x6000;
const RX_AVAIL_BASE: u64 = 0x7000;
const RX_USED_BASE: u64 = 0x8000;
const RX_BUF_BASE: u64 = 0x9000;
// Queue size programmed into both queues; fuzzed chains stay well below it.
const PROPTEST_QUEUE_SIZE: u16 = 16;
const MAX_CHAIN_LEN: usize = 8;
// Virtio descriptor flag bits (virtio spec, struct virtq_desc.flags).
const VRING_DESC_F_NEXT: u16 = 1;
const VRING_DESC_F_WRITE: u16 = 2;
/// A raw, possibly-invalid virtqueue descriptor produced by proptest.
///
/// Fields mirror `struct virtq_desc`. Values are deliberately left
/// unconstrained by the strategies (addresses may point outside guest
/// memory, `next` may exceed the queue size) so the device's chain-walking
/// code is exercised on hostile input.
#[derive(Debug, Clone, Copy)]
struct FuzzDesc {
    addr: u64,  // guest-physical buffer address
    len: u32,   // buffer length in bytes
    flags: u16, // VRING_DESC_F_* bits
    next: u16,  // index of the following descriptor when F_NEXT is set
}
/// Strategy yielding one arbitrary (and possibly bogus) descriptor.
///
/// Addresses are confined to the low 16 MiB, lengths to under 8 MiB, and
/// flags to the low three bits; the `next` link is fully arbitrary.
fn fuzz_desc_strategy() -> impl Strategy<Value = FuzzDesc> {
    let addr_s = 0u64..(1u64 << 24);
    let len_s = 0u32..(8 * 1024 * 1024);
    let flags_s = 0u16..8;
    let next_s = any::<u16>();
    (addr_s, len_s, flags_s, next_s).prop_map(|(addr, len, flags, next)| FuzzDesc {
        addr,
        len,
        flags,
        next,
    })
}
/// Strategy yielding a descriptor chain of 1..=MAX_CHAIN_LEN fuzzed entries.
fn fuzz_chain_strategy() -> impl Strategy<Value = Vec<FuzzDesc>> {
    let chain_len = 1..=MAX_CHAIN_LEN;
    prop::collection::vec(fuzz_desc_strategy(), chain_len)
}
/// Strategy yielding 12 fully arbitrary bytes — the size the TX tests use
/// for the per-frame header prefix (presumably the device's virtio-net
/// header size; the tests plant payload at offset 12 — TODO confirm against
/// the device module).
fn fuzz_header_strategy() -> impl Strategy<Value = [u8; 12]> {
    any::<[u8; 12]>()
}
/// Builds a fully initialised device plus its backing guest memory.
///
/// The device is driven through the virtio MMIO bring-up used by all the
/// properties below — status handshake, feature negotiation, queue
/// programming — and is left with DRIVER_OK status, ready to service
/// queue notifications.
fn build_fuzz_fixture() -> (VirtioNet, GuestMemoryMmap) {
    let guest_mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), GUEST_MEM_SIZE)])
        .expect("create proptest guest mem");
    let mut device = VirtioNet::new(NetConfig::default());
    device.set_mem(guest_mem.clone());
    init_until_features_ok(&mut device);
    program_queues(&mut device);
    write_reg(&mut device, VIRTIO_MMIO_STATUS, S_OK);
    (device, guest_mem)
}
/// Drives the device through virtio status and feature negotiation up to
/// FEATURES_OK, acknowledging MAC (feature word 0) and VERSION_1
/// (feature word 1).
fn init_until_features_ok(dev: &mut VirtioNet) {
    // Status handshake: ACKNOWLEDGE, then DRIVER.
    write_reg(dev, VIRTIO_MMIO_STATUS, S_ACK);
    write_reg(dev, VIRTIO_MMIO_STATUS, S_DRV);
    // Select driver-features word 0 and accept VIRTIO_NET_F_MAC.
    write_reg(dev, VIRTIO_MMIO_DRIVER_FEATURES_SEL, 0);
    write_reg(dev, VIRTIO_MMIO_DRIVER_FEATURES, 1u32 << VIRTIO_NET_F_MAC);
    // Select word 1 and accept VIRTIO_F_VERSION_1 (overall feature bit 32,
    // hence the -32 shift into the high 32-bit word).
    write_reg(dev, VIRTIO_MMIO_DRIVER_FEATURES_SEL, 1);
    write_reg(
        dev,
        VIRTIO_MMIO_DRIVER_FEATURES,
        1u32 << (VIRTIO_F_VERSION_1 - 32),
    );
    // Signal that feature negotiation is complete (FEATURES_OK).
    write_reg(dev, VIRTIO_MMIO_STATUS, S_FEAT);
}
/// Programs both queues (RX first, then TX) with PROPTEST_QUEUE_SIZE
/// entries and the fixed ring base addresses, then marks each queue ready.
fn program_queues(dev: &mut VirtioNet) {
    let layouts = [
        (RXQ as u32, RX_DESC_BASE, RX_AVAIL_BASE, RX_USED_BASE),
        (TXQ as u32, TX_DESC_BASE, TX_AVAIL_BASE, TX_USED_BASE),
    ];
    for (queue, desc_base, avail_base, used_base) in layouts.iter().copied() {
        write_reg(dev, VIRTIO_MMIO_QUEUE_SEL, queue);
        write_reg(dev, VIRTIO_MMIO_QUEUE_NUM, PROPTEST_QUEUE_SIZE as u32);
        write_reg(dev, VIRTIO_MMIO_QUEUE_DESC_LOW, desc_base as u32);
        write_reg(dev, VIRTIO_MMIO_QUEUE_AVAIL_LOW, avail_base as u32);
        write_reg(dev, VIRTIO_MMIO_QUEUE_USED_LOW, used_base as u32);
        write_reg(dev, VIRTIO_MMIO_QUEUE_READY, 1);
    }
}
/// Performs a 32-bit little-endian MMIO register write on the device.
fn write_reg(dev: &mut VirtioNet, offset: u32, val: u32) {
    let bytes = val.to_le_bytes();
    dev.mmio_write(u64::from(offset), &bytes);
}
/// Reads the `idx` field of a used ring (2 bytes past its `flags` field).
fn read_used_idx(mem: &GuestMemoryMmap, used_base: u64) -> u16 {
    let idx_addr = GuestAddress(used_base + 2);
    mem.read_obj::<u16>(idx_addr).expect("read used.idx")
}
/// Serialises one 16-byte virtq descriptor (addr, len, flags, next — all
/// little-endian) into slot `idx` of the descriptor table at `table_base`,
/// as a single guest-memory write.
fn write_desc(
    mem: &GuestMemoryMmap,
    table_base: u64,
    idx: u16,
    addr: u64,
    len: u32,
    flags: u16,
    next: u16,
) {
    const DESC_SIZE: u64 = 16;
    let slot = GuestAddress(table_base + u64::from(idx) * DESC_SIZE);
    let mut raw = Vec::with_capacity(DESC_SIZE as usize);
    raw.extend_from_slice(&addr.to_le_bytes());
    raw.extend_from_slice(&len.to_le_bytes());
    raw.extend_from_slice(&flags.to_le_bytes());
    raw.extend_from_slice(&next.to_le_bytes());
    mem.write_slice(&raw, slot).expect("plant descriptor");
}
/// Publishes descriptor-chain head `head_idx` in slot `ring_pos` of the
/// available ring at `avail_base`, then advances `avail.idx` past it.
///
/// The ring entry is written before `avail.idx` so the device can never
/// observe an index covering an unwritten slot. The new index is computed
/// with `wrapping_add`: virtio's avail.idx is a free-running counter
/// modulo 2^16, and plain `ring_pos + 1` would panic in debug builds at
/// `ring_pos == u16::MAX`.
fn publish_avail(mem: &GuestMemoryMmap, avail_base: u64, head_idx: u16, ring_pos: u16) {
    // Available ring layout: flags (u16), idx (u16), ring[] (u16 each).
    let ring_off = avail_base + 4 + (ring_pos as u64) * 2;
    mem.write_slice(&head_idx.to_le_bytes(), GuestAddress(ring_off))
        .expect("write avail.ring entry");
    let idx_off = avail_base + 2;
    let new_idx = ring_pos.wrapping_add(1);
    mem.write_slice(&new_idx.to_le_bytes(), GuestAddress(idx_off))
        .expect("write avail.idx");
}
/// Writes `descs` into consecutive TX descriptor-table slots and publishes
/// slot 0 as the head of a single available chain.
fn plant_tx_chain(mem: &GuestMemoryMmap, descs: &[FuzzDesc]) {
    for (slot, desc) in descs.iter().enumerate() {
        write_desc(
            mem,
            TX_DESC_BASE,
            slot as u16,
            desc.addr,
            desc.len,
            desc.flags,
            desc.next,
        );
    }
    publish_avail(mem, TX_AVAIL_BASE, 0, 0);
}
/// Writes `descs` into consecutive RX descriptor-table slots and publishes
/// slot 0 as the head of a single available chain.
fn plant_rx_chain(mem: &GuestMemoryMmap, descs: &[FuzzDesc]) {
    for (slot, desc) in descs.iter().enumerate() {
        write_desc(
            mem,
            RX_DESC_BASE,
            slot as u16,
            desc.addr,
            desc.len,
            desc.flags,
            desc.next,
        );
    }
    publish_avail(mem, RX_AVAIL_BASE, 0, 0);
}
/// Plants a single valid, device-writable 2 KiB RX buffer and publishes it,
/// so the TX-notify paths under test always have an RX buffer available
/// (keeping `tx_dropped_no_rx_buffer` from masking other outcomes).
fn plant_well_formed_rx_chain(mem: &GuestMemoryMmap) {
    let (buf_addr, buf_len) = (RX_BUF_BASE, 2048u32);
    write_desc(mem, RX_DESC_BASE, 0, buf_addr, buf_len, VRING_DESC_F_WRITE, 0);
    publish_avail(mem, RX_AVAIL_BASE, 0, 0);
}
/// Point-in-time copy of the device's atomic counters, taken so tests can
/// compare before/after values with plain integer arithmetic.
#[derive(Default, Clone, Copy, Debug)]
struct CounterSnapshot {
    // Successful-traffic counters.
    tx_packets: u64,
    tx_bytes: u64,
    rx_packets: u64,
    rx_bytes: u64,
    // Drop / failure diagnostics (field names mirror the device's counters).
    tx_dropped_no_rx_buffer: u64,
    tx_chain_invalid: u64,
    rx_chain_invalid: u64,
    rx_write_failed: u64,
    tx_add_used_failures: u64,
    rx_add_used_failures: u64,
    invalid_avail_idx_count: u64,
}
/// Loads every device counter into a plain-value snapshot.
///
/// `Ordering::Relaxed` suffices here: these properties drive the device
/// from the test thread itself, so only the values matter, not any
/// cross-thread ordering guarantees.
fn snapshot_counters(dev: &VirtioNet) -> CounterSnapshot {
    let c = dev.counters();
    CounterSnapshot {
        tx_packets: c.tx_packets.load(Ordering::Relaxed),
        tx_bytes: c.tx_bytes.load(Ordering::Relaxed),
        rx_packets: c.rx_packets.load(Ordering::Relaxed),
        rx_bytes: c.rx_bytes.load(Ordering::Relaxed),
        tx_dropped_no_rx_buffer: c.tx_dropped_no_rx_buffer.load(Ordering::Relaxed),
        tx_chain_invalid: c.tx_chain_invalid.load(Ordering::Relaxed),
        rx_chain_invalid: c.rx_chain_invalid.load(Ordering::Relaxed),
        rx_write_failed: c.rx_write_failed.load(Ordering::Relaxed),
        tx_add_used_failures: c.tx_add_used_failures.load(Ordering::Relaxed),
        rx_add_used_failures: c.rx_add_used_failures.load(Ordering::Relaxed),
        invalid_avail_idx_count: c.invalid_avail_idx_count.load(Ordering::Relaxed),
    }
}
/// Sums the increase across every "progress" counter: packets processed
/// plus each drop/failure class. The byte counters (`tx_bytes`/`rx_bytes`)
/// are intentionally excluded from this sum.
///
/// Callers assert monotonicity first; a counter that shrank would make the
/// `u64` subtraction here panic in debug builds, same as the straight-line
/// form this replaces.
fn counter_delta(before: &CounterSnapshot, after: &CounterSnapshot) -> u64 {
    let pairs = [
        (before.tx_packets, after.tx_packets),
        (before.rx_packets, after.rx_packets),
        (before.tx_dropped_no_rx_buffer, after.tx_dropped_no_rx_buffer),
        (before.tx_chain_invalid, after.tx_chain_invalid),
        (before.rx_chain_invalid, after.rx_chain_invalid),
        (before.rx_write_failed, after.rx_write_failed),
        (before.tx_add_used_failures, after.tx_add_used_failures),
        (before.rx_add_used_failures, after.rx_add_used_failures),
        (before.invalid_avail_idx_count, after.invalid_avail_idx_count),
    ];
    pairs.iter().map(|&(b, a)| a - b).sum()
}
/// Asserts that no counter decreased between two snapshots.
///
/// Uses `prop_assert!` (hence the `TestCaseError` result) instead of
/// `assert!` so that a violation is reported as a shrinkable proptest
/// failure rather than a panic. One assert per field keeps the failure
/// message naming the exact counter that regressed.
fn assert_counter_monotonicity(
    before: &CounterSnapshot,
    after: &CounterSnapshot,
) -> Result<(), TestCaseError> {
    prop_assert!(after.tx_packets >= before.tx_packets);
    prop_assert!(after.tx_bytes >= before.tx_bytes);
    prop_assert!(after.rx_packets >= before.rx_packets);
    prop_assert!(after.rx_bytes >= before.rx_bytes);
    prop_assert!(after.tx_dropped_no_rx_buffer >= before.tx_dropped_no_rx_buffer);
    prop_assert!(after.tx_chain_invalid >= before.tx_chain_invalid);
    prop_assert!(after.rx_chain_invalid >= before.rx_chain_invalid);
    prop_assert!(after.rx_write_failed >= before.rx_write_failed);
    prop_assert!(after.tx_add_used_failures >= before.tx_add_used_failures);
    prop_assert!(after.rx_add_used_failures >= before.rx_add_used_failures);
    prop_assert!(after.invalid_avail_idx_count >= before.invalid_avail_idx_count);
    Ok(())
}
proptest! {
    // Bounded case count keeps the suite fast while leaving shrinking
    // enough budget to minimise any failure it finds.
    #![proptest_config(ProptestConfig {
        cases: 256,
        max_shrink_iters: 1024,
        .. ProptestConfig::default()
    })]
    // Property: notifying TX with a fully random descriptor chain must
    // always produce externally visible progress — a used-ring advance on
    // either queue, or at least one diagnostic counter bump. A silent
    // no-op would mean the device can wedge on hostile guest input.
    #[test]
    fn tx_chain_progress_under_random_descriptors(
        descs in fuzz_chain_strategy(),
    ) {
        let (mut dev, mem) = build_fuzz_fixture();
        plant_tx_chain(&mem, &descs);
        // Always supply one valid RX buffer so "no RX buffer" drops cannot
        // mask other failure modes.
        plant_well_formed_rx_chain(&mem);
        let before_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let before_rx_used = read_used_idx(&mem, RX_USED_BASE);
        let before = snapshot_counters(&dev);
        write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, TXQ as u32);
        let after_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let after_rx_used = read_used_idx(&mem, RX_USED_BASE);
        let after = snapshot_counters(&dev);
        assert_counter_monotonicity(&before, &after)?;
        // Progress = used.idx movement on both rings plus counter deltas.
        let used_delta = (after_tx_used - before_tx_used) as u64
            + (after_rx_used - before_rx_used) as u64;
        let cdelta = counter_delta(&before, &after);
        let progress = used_delta + cdelta;
        prop_assert!(
            progress >= 1,
            "no visible progress: tx_used_delta={} rx_used_delta={} \
            counter_delta={} (chain len={}, first_desc=({:#x},{},{:#x},{}))",
            (after_tx_used - before_tx_used) as u64,
            (after_rx_used - before_rx_used) as u64,
            cdelta,
            descs.len(),
            descs[0].addr,
            descs[0].len,
            descs[0].flags,
            descs[0].next,
        );
    }
    // Property: mirror of the TX fuzz above — a well-formed TX frame plus
    // an arbitrary RX descriptor chain must still yield visible progress
    // on notify.
    #[test]
    fn rx_chain_progress_under_random_descriptors(
        rx_descs in fuzz_chain_strategy(),
    ) {
        let (mut dev, mem) = build_fuzz_fixture();
        // Plant a valid frame (12-byte zero header + 16-byte payload) so
        // the TX side cannot be the reason progress stalls.
        let zero_hdr = [0u8; 12];
        let payload: [u8; 16] = [0x42; 16];
        mem.write_slice(&zero_hdr, GuestAddress(TX_FRAME_BUF_BASE))
            .expect("plant zero header");
        mem.write_slice(&payload, GuestAddress(TX_FRAME_BUF_BASE + 12))
            .expect("plant payload");
        let tx_total = (12 + payload.len()) as u32;
        write_desc(&mem, TX_DESC_BASE, 0, TX_FRAME_BUF_BASE, tx_total, 0, 0);
        publish_avail(&mem, TX_AVAIL_BASE, 0, 0);
        plant_rx_chain(&mem, &rx_descs);
        let before_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let before_rx_used = read_used_idx(&mem, RX_USED_BASE);
        let before = snapshot_counters(&dev);
        write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, TXQ as u32);
        let after_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let after_rx_used = read_used_idx(&mem, RX_USED_BASE);
        let after = snapshot_counters(&dev);
        assert_counter_monotonicity(&before, &after)?;
        let tx_used_delta = (after_tx_used - before_tx_used) as u64;
        let rx_used_delta = (after_rx_used - before_rx_used) as u64;
        let cdelta = counter_delta(&before, &after);
        prop_assert!(
            tx_used_delta + rx_used_delta + cdelta >= 1,
            "no visible progress: tx_used_delta={} \
            rx_used_delta={} counter_delta={} \
            (rx chain len={}, first_rx_desc=({:#x},{},{:#x},{}))",
            tx_used_delta,
            rx_used_delta,
            cdelta,
            rx_descs.len(),
            rx_descs[0].addr,
            rx_descs[0].len,
            rx_descs[0].flags,
            rx_descs[0].next,
        );
    }
    // Property: a structurally valid TX chain whose 12 header bytes are
    // arbitrary must still be consumed (used.idx advances by exactly 1)
    // and counted as exactly one TX packet — header contents must not be
    // able to stall the queue.
    #[test]
    fn random_tx_header_either_loops_or_records_failure(
        hdr_bytes in fuzz_header_strategy(),
    ) {
        let (mut dev, mem) = build_fuzz_fixture();
        mem.write_slice(&hdr_bytes, GuestAddress(TX_FRAME_BUF_BASE))
            .expect("plant fuzzed header");
        let payload: [u8; 16] = [0xAB; 16];
        mem.write_slice(&payload, GuestAddress(TX_FRAME_BUF_BASE + 12))
            .expect("plant payload");
        let tx_total = (12 + payload.len()) as u32;
        write_desc(&mem, TX_DESC_BASE, 0, TX_FRAME_BUF_BASE, tx_total, 0, 0);
        publish_avail(&mem, TX_AVAIL_BASE, 0, 0);
        plant_well_formed_rx_chain(&mem);
        let before_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let before = snapshot_counters(&dev);
        write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, TXQ as u32);
        let after_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let after = snapshot_counters(&dev);
        assert_counter_monotonicity(&before, &after)?;
        prop_assert_eq!(
            after_tx_used - before_tx_used,
            1,
            "TX chain with valid shape and arbitrary header bytes \
            must always complete TX add_used; header={:?}",
            hdr_bytes,
        );
        prop_assert_eq!(
            after.tx_packets - before.tx_packets,
            1,
            "well-formed TX chain must bump tx_packets exactly once",
        );
    }
    // Property: an arbitrary descriptor length — including lengths far past
    // the planted data or past guest memory — must still consume the chain:
    // used.idx advances by one, and exactly one of tx_packets /
    // tx_chain_invalid is bumped (never both, never neither).
    #[test]
    fn random_tx_desc_len_either_truncates_or_records_failure(
        len in 0u32..(8u32 * 1024 * 1024),
    ) {
        let (mut dev, mem) = build_fuzz_fixture();
        // Only back up to 64 KiB of the claimed length with real bytes;
        // larger lengths deliberately point past the planted data.
        let safe_fill_len = (len as usize).min(0x10_000); let zero_hdr = [0u8; 12];
        mem.write_slice(&zero_hdr, GuestAddress(TX_FRAME_BUF_BASE))
            .expect("plant zero header");
        if safe_fill_len > 12 {
            let filler = vec![0xBBu8; safe_fill_len - 12];
            mem.write_slice(&filler, GuestAddress(TX_FRAME_BUF_BASE + 12))
                .expect("plant filler");
        }
        write_desc(&mem, TX_DESC_BASE, 0, TX_FRAME_BUF_BASE, len, 0, 0);
        publish_avail(&mem, TX_AVAIL_BASE, 0, 0);
        plant_well_formed_rx_chain(&mem);
        let before_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let before = snapshot_counters(&dev);
        write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, TXQ as u32);
        let after_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let after = snapshot_counters(&dev);
        assert_counter_monotonicity(&before, &after)?;
        prop_assert_eq!(
            after_tx_used - before_tx_used,
            1,
            "TX must advance used.idx by 1 per popped chain regardless of \
            read failure; len={}",
            len,
        );
        let tx_pkt_delta = after.tx_packets - before.tx_packets;
        let tx_inv_delta = after.tx_chain_invalid - before.tx_chain_invalid;
        prop_assert_eq!(
            tx_pkt_delta + tx_inv_delta,
            1,
            "exactly one of tx_packets/tx_chain_invalid must bump per \
            popped TX chain; len={} pkt_delta={} inv_delta={}",
            len,
            tx_pkt_delta,
            tx_inv_delta,
        );
    }
    // Property: arbitrary flag bits (low 4 bits) on a single otherwise-valid
    // TX descriptor must never stall the queue — some progress (used-ring
    // advance or counter bump) is always visible after notify.
    #[test]
    fn random_tx_desc_flags_either_loops_or_records_failure(
        flags in 0u16..16,
    ) {
        let (mut dev, mem) = build_fuzz_fixture();
        let zero_hdr = [0u8; 12];
        let payload: [u8; 16] = [0xCC; 16];
        mem.write_slice(&zero_hdr, GuestAddress(TX_FRAME_BUF_BASE))
            .expect("plant zero header");
        mem.write_slice(&payload, GuestAddress(TX_FRAME_BUF_BASE + 12))
            .expect("plant payload");
        let tx_total = (12 + payload.len()) as u32;
        write_desc(&mem, TX_DESC_BASE, 0, TX_FRAME_BUF_BASE, tx_total, flags, 0);
        publish_avail(&mem, TX_AVAIL_BASE, 0, 0);
        plant_well_formed_rx_chain(&mem);
        let before_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let before = snapshot_counters(&dev);
        write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, TXQ as u32);
        let after_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let after = snapshot_counters(&dev);
        assert_counter_monotonicity(&before, &after)?;
        let tx_used_delta = (after_tx_used - before_tx_used) as u64;
        let cdelta = counter_delta(&before, &after);
        prop_assert!(
            tx_used_delta + cdelta >= 1,
            "no progress with TX flags={:#x}: tx_used_delta={} \
            counter_delta={}",
            flags,
            tx_used_delta,
            cdelta,
        );
    }
    // Property: a two-descriptor chain whose first link's `next` index is
    // fully arbitrary (possibly out of range, possibly self-referential)
    // must still be consumed exactly once: used.idx advances by 1 and
    // exactly one of tx_packets / tx_chain_invalid is bumped.
    #[test]
    fn random_tx_next_link_either_loops_or_truncates(
        next in any::<u16>(),
    ) {
        let (mut dev, mem) = build_fuzz_fixture();
        let zero_hdr = [0u8; 12];
        let payload: [u8; 16] = [0xDD; 16];
        mem.write_slice(&zero_hdr, GuestAddress(TX_FRAME_BUF_BASE))
            .expect("plant zero header");
        mem.write_slice(
            &payload,
            GuestAddress(TX_FRAME_BUF_BASE + 0x100),
        )
        .expect("plant payload");
        // Descriptor 0 carries the header and chains to the fuzzed `next`.
        write_desc(
            &mem,
            TX_DESC_BASE,
            0,
            TX_FRAME_BUF_BASE,
            12,
            VRING_DESC_F_NEXT,
            next,
        );
        // Descriptor 1 is the intended continuation (hit when next == 1).
        write_desc(
            &mem,
            TX_DESC_BASE,
            1,
            TX_FRAME_BUF_BASE + 0x100,
            payload.len() as u32,
            0,
            0,
        );
        publish_avail(&mem, TX_AVAIL_BASE, 0, 0);
        plant_well_formed_rx_chain(&mem);
        let before_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let before = snapshot_counters(&dev);
        write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, TXQ as u32);
        let after_tx_used = read_used_idx(&mem, TX_USED_BASE);
        let after = snapshot_counters(&dev);
        assert_counter_monotonicity(&before, &after)?;
        prop_assert_eq!(
            after_tx_used - before_tx_used,
            1,
            "TX must advance used.idx by 1 regardless of next link; next={}",
            next,
        );
        let tx_pkt_delta = after.tx_packets - before.tx_packets;
        let tx_inv_delta = after.tx_chain_invalid - before.tx_chain_invalid;
        prop_assert_eq!(
            tx_pkt_delta + tx_inv_delta,
            1,
            "exactly one of tx_packets/tx_chain_invalid must bump; \
            next={} pkt_delta={} inv_delta={}",
            next,
            tx_pkt_delta,
            tx_inv_delta,
        );
    }
}