use super::{
DiskThrottle, REQ_QUEUE, VIRTIO_BLK_OUTHDR_SIZE, VIRTIO_BLK_S_IOERR, VIRTIO_BLK_S_OK,
VIRTIO_BLK_S_UNSUPP, VIRTIO_BLK_T_FLUSH, VIRTIO_BLK_T_IN, VIRTIO_BLK_T_OUT,
VIRTIO_MMIO_QUEUE_NOTIFY, VirtioBlk, VirtioBlkOutHdr,
};
use proptest::prelude::*;
use std::num::NonZeroU64;
use std::os::unix::fs::FileExt;
use std::sync::atomic::Ordering;
use tempfile::tempfile;
use virtio_bindings::bindings::virtio_ring::{VRING_DESC_F_NEXT, VRING_DESC_F_WRITE};
use virtio_queue::QueueT;
use virtio_queue::desc::{RawDescriptor, split::Descriptor as SplitDescriptor};
use virtio_queue::mock::MockSplitQueue;
use vm_memory::{Address, Bytes, GuestAddress, GuestMemoryMmap};
/// One fully arbitrary split-ring descriptor, as produced by
/// `fuzz_desc_strategy`; nothing about it is guaranteed to be valid.
#[derive(Debug, Clone, Copy)]
struct FuzzDesc {
    /// Guest-physical buffer address (may lie outside mapped guest memory).
    addr: u64,
    /// Buffer length in bytes.
    len: u32,
    /// Raw descriptor flag bits.
    flags: u16,
    /// Raw `next` descriptor-table index.
    next: u16,
}
/// Strategy for a single arbitrary descriptor.
///
/// `addr` ranges over 16 MiB, well past the 1 MiB of guest memory the
/// fixtures map. `len` goes up to 8 MiB. `flags` covers every combination of
/// the low three flag bits. `next` is any `u16` table index.
fn fuzz_desc_strategy() -> impl Strategy<Value = FuzzDesc> {
    let addr_range = 0u64..(1u64 << 24);
    let len_range = 0u32..(8 * 1024 * 1024);
    let flag_bits = 0u16..8;
    let next_index = any::<u16>();
    (addr_range, len_range, flag_bits, next_index)
        .prop_map(|(addr, len, flags, next)| FuzzDesc { addr, len, flags, next })
}
/// Strategy for a run of 1 to 200 arbitrary descriptors. The progress test
/// stores them into the descriptor table starting at slot 0.
fn fuzz_chain_strategy() -> impl Strategy<Value = Vec<FuzzDesc>> {
    let chain_len = 1..=200usize;
    prop::collection::vec(fuzz_desc_strategy(), chain_len)
}
/// Builds an unthrottled device plus 1 MiB of guest memory mapped at guest
/// address 0.
///
/// The device is backed by a 4 KiB tempfile in which every byte is `0xAB`.
/// Callers attach the memory (`set_mem`) and wire the queue themselves.
fn build_fuzz_fixture() -> (VirtioBlk, GuestMemoryMmap) {
    let cap = 4096u64;
    let f = tempfile().expect("create tempfile for fuzz backing");
    f.set_len(cap).expect("set tempfile length to fuzz cap");
    // `write_at` may legally stop short and its byte count would be discarded
    // here. `write_all_at` keeps writing until the whole pattern is stored.
    // The pattern length follows `cap` so the two cannot drift apart.
    f.write_all_at(&vec![0xABu8; cap as usize], 0)
        .expect("seed backing pattern");
    let dev = VirtioBlk::new(f, cap, DiskThrottle::default());
    let mem = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 1 << 20)])
        .expect("create proptest guest mem");
    (dev, mem)
}
fn wire_fuzz_device(dev: &mut VirtioBlk, mock: &MockSplitQueue<GuestMemoryMmap>) {
use super::{
QUEUE_MAX_SIZE, S_ACK, S_DRV, S_FEAT, S_OK, VIRTIO_MMIO_DRIVER_FEATURES,
VIRTIO_MMIO_DRIVER_FEATURES_SEL, VIRTIO_MMIO_QUEUE_AVAIL_HIGH, VIRTIO_MMIO_QUEUE_AVAIL_LOW,
VIRTIO_MMIO_QUEUE_DESC_HIGH, VIRTIO_MMIO_QUEUE_DESC_LOW, VIRTIO_MMIO_QUEUE_NUM,
VIRTIO_MMIO_QUEUE_READY, VIRTIO_MMIO_QUEUE_SEL, VIRTIO_MMIO_QUEUE_USED_HIGH,
VIRTIO_MMIO_QUEUE_USED_LOW, VIRTIO_MMIO_STATUS,
};
use virtio_bindings::virtio_config::VIRTIO_F_VERSION_1;
let write_reg = |dev: &mut VirtioBlk, offset: u32, val: u32| {
dev.mmio_write(offset as u64, &val.to_le_bytes());
};
write_reg(dev, VIRTIO_MMIO_STATUS, S_ACK);
write_reg(dev, VIRTIO_MMIO_STATUS, S_DRV);
write_reg(dev, VIRTIO_MMIO_DRIVER_FEATURES_SEL, 1);
write_reg(
dev,
VIRTIO_MMIO_DRIVER_FEATURES,
1 << (VIRTIO_F_VERSION_1 - 32),
);
write_reg(dev, VIRTIO_MMIO_STATUS, S_FEAT);
write_reg(dev, VIRTIO_MMIO_QUEUE_SEL, 0);
write_reg(dev, VIRTIO_MMIO_QUEUE_NUM, QUEUE_MAX_SIZE as u32);
let desc = mock.desc_table_addr().0;
let avail = mock.avail_addr().0;
let used = mock.used_addr().0;
write_reg(dev, VIRTIO_MMIO_QUEUE_DESC_LOW, desc as u32);
write_reg(dev, VIRTIO_MMIO_QUEUE_DESC_HIGH, (desc >> 32) as u32);
write_reg(dev, VIRTIO_MMIO_QUEUE_AVAIL_LOW, avail as u32);
write_reg(dev, VIRTIO_MMIO_QUEUE_AVAIL_HIGH, (avail >> 32) as u32);
write_reg(dev, VIRTIO_MMIO_QUEUE_USED_LOW, used as u32);
write_reg(dev, VIRTIO_MMIO_QUEUE_USED_HIGH, (used >> 32) as u32);
write_reg(dev, VIRTIO_MMIO_QUEUE_READY, 1);
write_reg(dev, VIRTIO_MMIO_STATUS, S_OK);
}
/// Reads the `idx` field of `mock`'s used ring from guest memory.
///
/// The split used ring begins with a `u16` `flags` field followed by the
/// `u16` `idx`, hence the 2-byte offset. Virtio ring fields are little-endian,
/// but `read_obj::<u16>` yields the host's native byte order. Converting with
/// `u16::from_le` makes the value correct on big-endian hosts too; it is a
/// no-op on little-endian ones.
fn read_used_idx(mem: &GuestMemoryMmap, mock: &MockSplitQueue<GuestMemoryMmap>) -> u16 {
    let idx_addr = mock.used_addr().checked_add(2).unwrap();
    u16::from_le(mem.read_obj::<u16>(idx_addr).expect("read used.idx"))
}
/// Point-in-time copy of the device's I/O counters, taken by
/// `snapshot_counters`, so two snapshots can be diffed around a notify.
///
/// `Debug`/`PartialEq`/`Eq` are derived so snapshots can be printed in
/// assertion messages and compared directly.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct CounterSnapshot {
    /// Completed read requests.
    reads: u64,
    /// Completed write requests.
    writes: u64,
    /// Completed flush requests.
    flushes: u64,
    /// Bytes transferred by reads.
    bytes_read: u64,
    /// Bytes transferred by writes.
    bytes_written: u64,
    /// Times the throttle stalled a request.
    throttled: u64,
    /// Requests that ended in an I/O error.
    io_errors: u64,
}
/// Copies every counter exposed by `dev.counters()` into a `CounterSnapshot`.
///
/// Each counter is loaded independently with `Ordering::Relaxed`.
fn snapshot_counters(dev: &VirtioBlk) -> CounterSnapshot {
    let relaxed = Ordering::Relaxed;
    let stats = dev.counters();

    let reads = stats.reads_completed.load(relaxed);
    let writes = stats.writes_completed.load(relaxed);
    let flushes = stats.flushes_completed.load(relaxed);
    let bytes_read = stats.bytes_read.load(relaxed);
    let bytes_written = stats.bytes_written.load(relaxed);
    let throttled = stats.throttled_count.load(relaxed);
    let io_errors = stats.io_errors.load(relaxed);

    CounterSnapshot {
        reads,
        writes,
        flushes,
        bytes_read,
        bytes_written,
        throttled,
        io_errors,
    }
}
/// Like `build_fuzz_fixture`, but the device has a 1-IOPS throttle whose
/// single token has already been consumed.
///
/// The backing is 4 KiB of `0xAB` and guest memory is 1 MiB at address 0.
/// The `assert!` on `consume(1)` confirms the drain actually took the token.
fn build_throttled_fuzz_fixture() -> (VirtioBlk, GuestMemoryMmap) {
    let cap = 4096u64;
    let f = tempfile().expect("create tempfile for throttled fuzz backing");
    f.set_len(cap).expect("set tempfile length to fuzz cap");
    // `write_at` may legally stop short; `write_all_at` stores the whole
    // pattern. The pattern length follows `cap`.
    f.write_all_at(&vec![0xABu8; cap as usize], 0)
        .expect("seed backing pattern");
    let throttle = DiskThrottle {
        iops: NonZeroU64::new(1),
        bytes_per_sec: None,
        iops_burst_capacity: None,
        bytes_burst_capacity: None,
    };
    let mut dev = VirtioBlk::new(f, cap, throttle);
    // The refill timestamp is pinned to one instant before and after the
    // drain. NOTE(review): presumably this keeps elapsed wall-clock time from
    // being credited back as a token; confirm against the bucket impl.
    let now = std::time::Instant::now();
    dev.worker
        .state_mut()
        .ops_bucket
        .set_last_refill_for_test(now);
    assert!(dev.worker.state_mut().ops_bucket.consume(1));
    dev.worker
        .state_mut()
        .ops_bucket
        .set_last_refill_for_test(now);
    let mem = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 1 << 20)])
        .expect("create proptest guest mem");
    (dev, mem)
}
/// Parameters for a structurally valid request chain, consumed by
/// `plant_well_formed_chain`.
#[derive(Debug, Clone)]
struct WellFormedChain {
    /// Request type written into the header (`VIRTIO_BLK_T_IN`/`OUT`/`FLUSH`).
    req_type: u32,
    /// Starting sector written into the header.
    sector: u64,
    /// Requested number of data segments. It may be capped when planted.
    n_data_segments: u32,
    /// Size of each data segment, in 512-byte sectors.
    seg_sectors: u32,
}
/// Strategy for `WellFormedChain` values.
///
/// - `req_type`: picked uniformly from IN, OUT and FLUSH.
/// - `sector`: `0..8`.
/// - `n_data_segments`: `1..=8`.
/// - `seg_sectors`: `1..=4`.
fn well_formed_chain_strategy() -> impl Strategy<Value = WellFormedChain> {
    let kind = prop_oneof![
        Just(VIRTIO_BLK_T_IN),
        Just(VIRTIO_BLK_T_OUT),
        Just(VIRTIO_BLK_T_FLUSH),
    ];
    let start_sector = 0u64..8u64;
    let segment_count = 1u32..=8u32;
    let sectors_per_segment = 1u32..=4u32;
    (kind, start_sector, segment_count, sectors_per_segment).prop_map(
        |(req_type, sector, n_data_segments, seg_sectors)| WellFormedChain {
            req_type,
            sector,
            n_data_segments,
            seg_sectors,
        },
    )
}
/// Writes a structurally valid virtio-blk request into guest memory and
/// publishes it on `mock` as one descriptor chain. Returns the status-byte
/// address.
///
/// Layout:
/// - Descriptor 0: `VirtioBlkOutHdr` at `0x4000`, device-readable.
/// - Descriptors 1..=n (IN/OUT only): data segments of `seg_sectors * 512`
///   bytes at `0x5000 + i * 0x800`. IN segments are device-writable; OUT
///   segments are device-readable.
/// - Last descriptor: 1-byte status at `0xC000`, device-writable, pre-seeded
///   with the sentinel `0xEE` so a status the device never writes stays
///   detectable.
///
/// The segment count is `n_data_segments`, capped at `8 / seg_sectors`
/// (minimum 1). For `seg_sectors` in `1..=8`, that keeps the total data at or
/// below 8 sectors (the 4 KiB fixture backing).
fn plant_well_formed_chain(
    mem: &GuestMemoryMmap,
    mock: &MockSplitQueue<GuestMemoryMmap>,
    chain: &WellFormedChain,
) -> GuestAddress {
    // Fixture backing size in 512-byte sectors (4096 / 512).
    const MAX_TOTAL_SECTORS: u32 = 8;
    // Distance between consecutive data segments in guest memory.
    const SEG_STRIDE: u64 = 0x800;
    let header_addr = GuestAddress(0x4000);
    let data_base = 0x5000u64;
    let status_addr = GuestAddress(0xC000);

    let hdr = VirtioBlkOutHdr {
        type_: chain.req_type,
        _ioprio: 0,
        sector: chain.sector,
    };
    mem.write_obj(hdr, header_addr).expect("plant header");
    mem.write_slice(&[0xEEu8], status_addr)
        .expect("plant status sentinel");

    // FLUSH carries no data. The division only happens for IN/OUT.
    let seg_count = if chain.req_type == VIRTIO_BLK_T_FLUSH {
        0
    } else {
        (MAX_TOTAL_SECTORS / chain.seg_sectors)
            .max(1)
            .min(chain.n_data_segments)
    };

    let mut descs: Vec<RawDescriptor> = Vec::with_capacity(seg_count as usize + 2);
    // The header always links to descriptor 1: the first data segment, or the
    // status descriptor when there is no data.
    descs.push(RawDescriptor::from(SplitDescriptor::new(
        header_addr.0,
        VIRTIO_BLK_OUTHDR_SIZE as u32,
        VRING_DESC_F_NEXT as u16,
        1,
    )));

    let data_flag = if chain.req_type == VIRTIO_BLK_T_IN {
        VRING_DESC_F_WRITE as u16
    } else {
        0u16
    };
    let seg_len = chain.seg_sectors * 512;
    for i in 0..seg_count {
        // Segment i sits at descriptor i + 1, so its successor is i + 2. After
        // the last segment, that successor is the status descriptor.
        descs.push(RawDescriptor::from(SplitDescriptor::new(
            data_base + u64::from(i) * SEG_STRIDE,
            seg_len,
            data_flag | VRING_DESC_F_NEXT as u16,
            (i + 2) as u16,
        )));
    }

    descs.push(RawDescriptor::from(SplitDescriptor::new(
        status_addr.0,
        1,
        VRING_DESC_F_WRITE as u16,
        0,
    )));
    mock.build_desc_chain(&descs).expect("build chain");
    status_addr
}
proptest! {
#![proptest_config(ProptestConfig {
cases: 256,
max_shrink_iters: 1024,
.. ProptestConfig::default()
})]
#[test]
fn process_requests_progress_under_random_chains(
descs in fuzz_chain_strategy(),
) {
let (mut dev, mem) = build_fuzz_fixture();
let mock = MockSplitQueue::create(&mem, GuestAddress(0), 256);
dev.set_mem(mem.clone());
wire_fuzz_device(&mut dev, &mock);
let raw_descs: Vec<RawDescriptor> = descs
.iter()
.map(|d| {
RawDescriptor::from(SplitDescriptor::new(
d.addr,
d.len,
d.flags,
d.next,
))
})
.collect();
mock.add_desc_chains(&raw_descs, 0)
.expect("plant descriptors into avail ring");
let before_used = read_used_idx(&mem, &mock);
let before = snapshot_counters(&dev);
dev.mmio_write(
VIRTIO_MMIO_QUEUE_NOTIFY as u64,
&(REQ_QUEUE as u32).to_le_bytes(),
);
let after_used = read_used_idx(&mem, &mock);
let after = snapshot_counters(&dev);
prop_assert!(after.reads >= before.reads);
prop_assert!(after.writes >= before.writes);
prop_assert!(after.flushes >= before.flushes);
prop_assert!(after.bytes_read >= before.bytes_read);
prop_assert!(after.bytes_written >= before.bytes_written);
prop_assert!(after.throttled >= before.throttled);
prop_assert!(after.io_errors >= before.io_errors);
prop_assert!(after_used >= before_used);
let used_delta = (after_used - before_used) as u64;
let counter_delta = (after.reads - before.reads)
+ (after.writes - before.writes)
+ (after.flushes - before.flushes)
+ (after.throttled - before.throttled)
+ (after.io_errors - before.io_errors);
let progress = used_delta + counter_delta;
prop_assert!(
progress >= 1,
"no visible progress: used_delta={} counter_delta={} \
(chain len={}, first_desc=({:#x},{},{:#x},{}))",
used_delta,
counter_delta,
descs.len(),
descs[0].addr,
descs[0].len,
descs[0].flags,
descs[0].next,
);
}
#[test]
fn random_header_addr_either_succeeds_or_ioerrs(
header_addr_low in 0u64..(1u64 << 24),
) {
let (mut dev, mem) = build_fuzz_fixture();
let mock = MockSplitQueue::create(&mem, GuestAddress(0), 256);
dev.set_mem(mem.clone());
wire_fuzz_device(&mut dev, &mock);
let status_addr = GuestAddress(0x6000);
mem.write_slice(&[0xEEu8], status_addr).unwrap();
let descs = [
RawDescriptor::from(SplitDescriptor::new(
header_addr_low,
VIRTIO_BLK_OUTHDR_SIZE as u32,
0, 1,
)),
RawDescriptor::from(SplitDescriptor::new(
status_addr.0,
1,
VRING_DESC_F_WRITE as u16,
0,
)),
];
mock.build_desc_chain(&descs).expect("build chain");
dev.mmio_write(
VIRTIO_MMIO_QUEUE_NOTIFY as u64,
&(REQ_QUEUE as u32).to_le_bytes(),
);
let mut s = [0u8; 1];
mem.read_slice(&mut s, status_addr).unwrap();
prop_assert!(
s[0] == VIRTIO_BLK_S_OK as u8
|| s[0] == VIRTIO_BLK_S_IOERR as u8
|| s[0] == VIRTIO_BLK_S_UNSUPP as u8,
"status byte {:#x} is not a valid virtio-blk status",
s[0],
);
let used_idx = read_used_idx(&mem, &mock);
prop_assert_eq!(
used_idx,
1,
"well-formed chain shape with random header_addr must \
produce exactly one used-ring entry; got {}",
used_idx,
);
}
#[test]
fn random_data_len_either_succeeds_or_ioerrs(
data_len in 0u32..(8u32 * 1024 * 1024),
req_type in 0u32..=8u32,
) {
let (mut dev, mem) = build_fuzz_fixture();
let mock = MockSplitQueue::create(&mem, GuestAddress(0), 256);
dev.set_mem(mem.clone());
wire_fuzz_device(&mut dev, &mock);
let header_addr = GuestAddress(0x4000);
let data_addr = GuestAddress(0x5000);
let status_addr = GuestAddress(0x6000);
let hdr = VirtioBlkOutHdr {
type_: req_type,
_ioprio: 0,
sector: 0,
};
mem.write_obj(hdr, header_addr).expect("plant header");
mem.write_slice(&[0xEEu8], status_addr).unwrap();
let data_flags = if req_type == 1 {
0
} else {
VRING_DESC_F_WRITE as u16
};
let descs = [
RawDescriptor::from(SplitDescriptor::new(
header_addr.0,
VIRTIO_BLK_OUTHDR_SIZE as u32,
0,
1,
)),
RawDescriptor::from(SplitDescriptor::new(
data_addr.0,
data_len,
data_flags,
2,
)),
RawDescriptor::from(SplitDescriptor::new(
status_addr.0,
1,
VRING_DESC_F_WRITE as u16,
0,
)),
];
mock.build_desc_chain(&descs).expect("build chain");
dev.mmio_write(
VIRTIO_MMIO_QUEUE_NOTIFY as u64,
&(REQ_QUEUE as u32).to_le_bytes(),
);
let mut s = [0u8; 1];
mem.read_slice(&mut s, status_addr).unwrap();
prop_assert!(
s[0] == VIRTIO_BLK_S_OK as u8
|| s[0] == VIRTIO_BLK_S_IOERR as u8
|| s[0] == VIRTIO_BLK_S_UNSUPP as u8,
"status byte {:#x} is not a valid virtio-blk status",
s[0],
);
let used_idx = read_used_idx(&mem, &mock);
prop_assert_eq!(
used_idx,
1,
"fuzzed data_len chain must produce exactly one \
used-ring entry; got {}",
used_idx,
);
}
#[test]
fn random_flags_either_succeeds_or_ioerrs(
data_flags in 0u16..16,
) {
let (mut dev, mem) = build_fuzz_fixture();
let mock = MockSplitQueue::create(&mem, GuestAddress(0), 256);
dev.set_mem(mem.clone());
wire_fuzz_device(&mut dev, &mock);
let header_addr = GuestAddress(0x4000);
let data_addr = GuestAddress(0x5000);
let status_addr = GuestAddress(0x6000);
let hdr = VirtioBlkOutHdr {
type_: super::VIRTIO_BLK_T_IN,
_ioprio: 0,
sector: 0,
};
mem.write_obj(hdr, header_addr).expect("plant header");
mem.write_slice(&[0xEEu8], status_addr).unwrap();
let descs = [
RawDescriptor::from(SplitDescriptor::new(
header_addr.0,
VIRTIO_BLK_OUTHDR_SIZE as u32,
VRING_DESC_F_NEXT as u16,
1,
)),
RawDescriptor::from(SplitDescriptor::new(
data_addr.0,
512,
data_flags | VRING_DESC_F_NEXT as u16,
2,
)),
RawDescriptor::from(SplitDescriptor::new(
status_addr.0,
1,
VRING_DESC_F_WRITE as u16,
0,
)),
];
mock.add_desc_chains(&descs, 0).expect("plant descriptors");
let before_used = read_used_idx(&mem, &mock);
let before = snapshot_counters(&dev);
dev.mmio_write(
VIRTIO_MMIO_QUEUE_NOTIFY as u64,
&(REQ_QUEUE as u32).to_le_bytes(),
);
let after_used = read_used_idx(&mem, &mock);
let after = snapshot_counters(&dev);
let used_delta = (after_used - before_used) as u64;
let counter_delta = (after.reads - before.reads)
+ (after.writes - before.writes)
+ (after.flushes - before.flushes)
+ (after.throttled - before.throttled)
+ (after.io_errors - before.io_errors);
prop_assert!(
used_delta + counter_delta >= 1,
"no progress with data_flags={:#x}: \
used_delta={} counter_delta={}",
data_flags,
used_delta,
counter_delta,
);
}
#[test]
fn throttle_stall_under_random_chain_shapes_holds_invariants(
chain in well_formed_chain_strategy(),
) {
let (mut dev, mem) = build_throttled_fuzz_fixture();
let mock = MockSplitQueue::create(&mem, GuestAddress(0), 256);
dev.set_mem(mem.clone());
wire_fuzz_device(&mut dev, &mock);
let status_addr = plant_well_formed_chain(&mem, &mock, &chain);
let before = snapshot_counters(&dev);
dev.mmio_write(
VIRTIO_MMIO_QUEUE_NOTIFY as u64,
&(REQ_QUEUE as u32).to_le_bytes(),
);
let after = snapshot_counters(&dev);
prop_assert!(after.reads >= before.reads);
prop_assert!(after.writes >= before.writes);
prop_assert!(after.flushes >= before.flushes);
prop_assert!(after.bytes_read >= before.bytes_read);
prop_assert!(after.bytes_written >= before.bytes_written);
prop_assert!(after.throttled >= before.throttled);
prop_assert!(after.io_errors >= before.io_errors);
let throttled_delta = after.throttled - before.throttled;
let io_errors_delta = after.io_errors - before.io_errors;
prop_assert!(
throttled_delta + io_errors_delta >= 1,
"drained throttle must produce a stall or pre-throttle reject; \
throttled_delta={throttled_delta} io_errors_delta={io_errors_delta} \
chain={chain:?}",
);
prop_assert_eq!(
after.reads - before.reads, 0,
"drained throttle must not produce a successful read"
);
prop_assert_eq!(
after.writes - before.writes, 0,
"drained throttle must not produce a successful write"
);
prop_assert_eq!(
after.flushes - before.flushes, 0,
"drained throttle must not produce a successful flush"
);
if throttled_delta == 1 {
let mut s = [0u8; 1];
mem.read_slice(&mut s, status_addr)
.expect("read status sentinel");
prop_assert_eq!(
s[0], 0xEE,
"stalled chain must not write status byte; chain={:?}",
chain,
);
let post_stall_next_avail = dev.worker.queues[REQ_QUEUE].next_avail();
prop_assert_eq!(
post_stall_next_avail, 0u16,
"post-stall next_avail must rewind to 0; got {}",
post_stall_next_avail,
);
let gauge = dev.counters().currently_throttled_gauge.load(Ordering::Relaxed);
prop_assert_eq!(
gauge, 1,
"stalled-chain gauge must show 1 (false→true transition)",
);
}
if io_errors_delta >= 1 && throttled_delta == 0 {
let mut s = [0u8; 1];
mem.read_slice(&mut s, status_addr)
.expect("read status byte");
prop_assert!(
s[0] == VIRTIO_BLK_S_IOERR as u8 || s[0] == VIRTIO_BLK_S_OK as u8
|| s[0] == VIRTIO_BLK_S_UNSUPP as u8,
"pre-throttle reject must write a defined virtio-blk status; \
got status={:#x} chain={:?}",
s[0], chain,
);
}
}
}