#![cfg(test)]
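//! Unit tests for the virtio-blk MMIO device: worker join/timeout
//! outcomes, lock-free counter and interrupt-status semantics, the
//! throttle-gauge lifecycle, avail-ring poisoning, and status-FSM
//! gating (FEATURES_OK, NEEDS_RESET, pause/resume).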
use super::testing::*;
use super::*;
use std::sync::atomic::Ordering;
use std::thread;
use std::time::{Duration, Instant};
use tempfile::tempfile;
use virtio_bindings::bindings::virtio_ring::VRING_DESC_F_WRITE;
use virtio_queue::desc::{RawDescriptor, split::Descriptor as SplitDescriptor};
use virtio_queue::mock::MockSplitQueue;
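// Minimal BlkWorkerState for the join-timeout tests below: a throwaway
// tempfile backing, unlimited token buckets, and zeroed scratch fields.
// These tests only exercise the join plumbing, never the state itself.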
fn dummy_worker_state() -> BlkWorkerState {
BlkWorkerState {
backing: tempfile().expect("create tempfile for dummy_worker_state"),
ops_bucket: TokenBucket::unlimited(),
bytes_bucket: TokenBucket::unlimited(),
all_descs_scratch: Vec::new(),
io_buf_scratch: Vec::new(),
capacity_bytes: 0,
read_only: false,
counters: Arc::new(VirtioBlkCounters::default()),
currently_stalled: false,
queue_poisoned: false,
}
}
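// Happy path: a worker that returns immediately must come back as
// Joined, and well inside the drop budget.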
#[test]
fn join_worker_with_timeout_happy_path_returns_joined() {
let handle = std::thread::Builder::new()
.name("ktstr-vblk-test-happy".to_string())
.spawn(dummy_worker_state)
.expect("spawn happy-path worker");
let start = Instant::now();
let outcome = join_worker_with_timeout(handle, DROP_JOIN_TIMEOUT);
let elapsed = start.elapsed();
assert!(
matches!(outcome, JoinWithTimeoutOutcome::Joined(_)),
"expected Joined, got {:?}",
outcome_label(&outcome)
);
assert!(
elapsed < Duration::from_millis(100),
"happy-path join took {elapsed:?}, expected < 100ms"
);
}
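// A worker sleeping far past the budget must come back as TimedOut,
// no earlier than the requested timeout and no later than a generous
// recv_timeout overhead allowance.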
#[test]
fn join_worker_with_timeout_returns_timed_out_when_worker_blocks() {
let handle = std::thread::Builder::new()
.name("ktstr-vblk-test-timeout".to_string())
.spawn(|| {
std::thread::sleep(Duration::from_secs(60));
dummy_worker_state()
})
.expect("spawn timeout-path worker");
let start = Instant::now();
let outcome = join_worker_with_timeout(handle, Duration::from_millis(50));
let elapsed = start.elapsed();
assert!(
matches!(outcome, JoinWithTimeoutOutcome::TimedOut),
"expected TimedOut, got {:?}",
outcome_label(&outcome)
);
assert!(
elapsed >= Duration::from_millis(50),
"timeout fired too early at {elapsed:?}; expected >= 50ms"
);
assert!(
elapsed < Duration::from_millis(200),
"timeout fired too late at {elapsed:?}; expected < 200ms \
(recv_timeout overhead budget)"
);
}
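// A panicking worker must surface as Panicked (not hang, not Joined),
// as quickly as the happy path, with the panic payload intact.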
#[test]
fn join_worker_with_timeout_returns_panicked_on_worker_panic() {
let handle = std::thread::Builder::new()
.name("ktstr-vblk-test-panic".to_string())
.spawn(|| -> BlkWorkerState {
panic!("intentional panic from join_worker_with_timeout test");
})
.expect("spawn panic-path worker");
let start = Instant::now();
let outcome = join_worker_with_timeout(handle, DROP_JOIN_TIMEOUT);
let elapsed = start.elapsed();
assert!(
matches!(outcome, JoinWithTimeoutOutcome::Panicked(_)),
"expected Panicked, got {:?}",
outcome_label(&outcome)
);
assert!(
elapsed < Duration::from_millis(100),
"panic-path join took {elapsed:?}, expected < 100ms \
(parity with happy path)"
);
if let JoinWithTimeoutOutcome::Panicked(payload) = outcome {
assert_eq!(
panic_payload_str(&*payload),
"intentional panic from join_worker_with_timeout test",
"panic payload round-trip should preserve the &'static str"
);
}
}
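// Plain-text labels for assertion messages; the Panicked payload is a
// Box<dyn Any>, so the outcome enum presumably cannot derive Debug.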
fn outcome_label(o: &JoinWithTimeoutOutcome) -> &'static str {
match o {
JoinWithTimeoutOutcome::Joined(_) => "Joined",
JoinWithTimeoutOutcome::Panicked(_) => "Panicked",
JoinWithTimeoutOutcome::TimedOut => "TimedOut",
JoinWithTimeoutOutcome::HelperSpawnFailed => "HelperSpawnFailed",
JoinWithTimeoutOutcome::HelperDisconnected => "HelperDisconnected",
}
}
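// Guardrail: keep the reset and drop join budgets in lockstep (rationale
// in the assert message) and pin the shared value at 1s so a silent
// constant change shows up in review.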
#[test]
fn reset_join_timeout_matches_drop_budget() {
assert_eq!(
RESET_JOIN_TIMEOUT, DROP_JOIN_TIMEOUT,
"RESET_JOIN_TIMEOUT must equal DROP_JOIN_TIMEOUT — both \
paths run on a vCPU thread that the freeze coordinator \
may target with SIGRTMIN; asymmetric budgets would let \
reset() miss a rendezvous Drop wouldn't, or vice versa",
);
assert_eq!(RESET_JOIN_TIMEOUT, Duration::from_secs(1));
}
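// Reset-path variant of the timeout test: the worker wedges on a channel
// recv that never completes while _keep_alive_tx stays alive, so the
// join must give up within the (deliberately small) test budget.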
#[test]
fn reset_join_timeout_against_wedged_worker_returns_timed_out() {
use std::sync::mpsc as test_mpsc;
let (_keep_alive_tx, wedge_rx) = test_mpsc::channel::<()>();
let handle = std::thread::Builder::new()
.name("ktstr-vblk-test-wedged-reset".to_string())
.spawn(move || -> BlkWorkerState {
let _ = wedge_rx.recv();
dummy_worker_state()
})
.expect("spawn wedged worker");
const TEST_TIMEOUT: Duration = Duration::from_millis(100);
assert!(
TEST_TIMEOUT < RESET_JOIN_TIMEOUT,
"test budget must be smaller than RESET_JOIN_TIMEOUT \
so the test stays fast; a future RESET_JOIN_TIMEOUT \
tightening below 100 ms would require updating \
TEST_TIMEOUT here",
);
let start = Instant::now();
let outcome = join_worker_with_timeout(handle, TEST_TIMEOUT);
let elapsed = start.elapsed();
assert!(
matches!(outcome, JoinWithTimeoutOutcome::TimedOut),
"wedged worker must yield TimedOut, got {:?}",
outcome_label(&outcome)
);
assert!(
elapsed < TEST_TIMEOUT * 2,
"join_worker_with_timeout took {elapsed:?} for a \
wedged worker (budget {TEST_TIMEOUT:?}); the bound \
must hold so the production reset() path doesn't \
pin the vCPU thread when the worker is stuck"
);
}
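// 16 writer threads each fetch_or a distinct bit 1000 times; no update
// may be lost, so the final value must be the union of all 16 bits.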
#[test]
fn interrupt_status_concurrent_fetch_or_load() {
use std::sync::Barrier;
let dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
let int_status = Arc::clone(&dev.interrupt_status);
const NUM_WRITERS: u32 = 16;
let barrier = Arc::new(Barrier::new(NUM_WRITERS as usize + 1));
let mut handles = Vec::with_capacity(NUM_WRITERS as usize);
for bit in 0..NUM_WRITERS {
let int_status_w = Arc::clone(&int_status);
let barrier_w = Arc::clone(&barrier);
handles.push(thread::spawn(move || {
barrier_w.wait();
for _ in 0..1_000 {
int_status_w.fetch_or(1u32 << bit, Ordering::Release);
}
}));
}
barrier.wait();
for h in handles {
h.join().expect("writer thread join");
}
let expected_union = (1u32 << NUM_WRITERS) - 1;
let observed = int_status.load(Ordering::Acquire);
assert_eq!(
observed, expected_union,
"all NUM_WRITERS bits must be set; missing bits indicate \
a lost fetch_or update — observed {observed:#x}, \
expected {expected_union:#x}",
);
}
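// Race a setter (fetch_or of bit X) against an acker (fetch_and clearing
// bit Y). The two RMW ops touch disjoint bits, so X must end up set and
// Y clear under every interleaving.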
#[test]
fn interrupt_status_concurrent_set_and_ack() {
use std::sync::Barrier;
let dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
let int_status = Arc::clone(&dev.interrupt_status);
const BIT_X: u32 = 1 << 0;
const BIT_Y: u32 = 1 << 1;
int_status.store(BIT_Y, Ordering::Release);
let barrier = Arc::new(Barrier::new(3));
let int_status_a = Arc::clone(&int_status);
let barrier_a = Arc::clone(&barrier);
let setter = thread::spawn(move || {
barrier_a.wait();
for _ in 0..10_000 {
int_status_a.fetch_or(BIT_X, Ordering::Release);
}
});
let int_status_b = Arc::clone(&int_status);
let barrier_b = Arc::clone(&barrier);
let acker = thread::spawn(move || {
barrier_b.wait();
for _ in 0..10_000 {
int_status_b.fetch_and(!BIT_Y, Ordering::AcqRel);
}
});
barrier.wait();
setter.join().expect("setter join");
acker.join().expect("acker join");
let final_state = int_status.load(Ordering::Acquire);
assert_eq!(
final_state & BIT_X,
BIT_X,
"bit X must remain set after the race — fetch_or sets and \
fetch_and(!Y) is disjoint; if X is missing, fetch_and \
accidentally cleared it (atomicity violation)",
);
assert_eq!(
final_state & BIT_Y,
0,
"bit Y must be clear after the race — every iteration of \
thread B issues fetch_and(!Y); if Y is set, fetch_and \
missed an iteration (lost update)",
);
}
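// fetch_add stress on a counter seeded from the device's
// config_generation: every one of NUM_WRITERS * ITERATIONS_PER_WRITER
// increments must land.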
#[test]
fn config_generation_concurrent_fetch_add_load() {
use std::sync::Barrier;
let dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
let initial = dev.config_generation.load(Ordering::Acquire);
let counter = Arc::new(AtomicU32::new(initial));
const NUM_WRITERS: u32 = 16;
const ITERATIONS_PER_WRITER: u32 = 1_000;
let barrier = Arc::new(Barrier::new(NUM_WRITERS as usize + 1));
let mut handles = Vec::with_capacity(NUM_WRITERS as usize);
for _ in 0..NUM_WRITERS {
let counter_w = Arc::clone(&counter);
let barrier_w = Arc::clone(&barrier);
handles.push(thread::spawn(move || {
barrier_w.wait();
for _ in 0..ITERATIONS_PER_WRITER {
counter_w.fetch_add(1, Ordering::Release);
}
}));
}
barrier.wait();
for h in handles {
h.join().expect("writer join");
}
let final_value = counter.load(Ordering::Acquire);
let expected = initial.wrapping_add(NUM_WRITERS * ITERATIONS_PER_WRITER);
assert_eq!(
final_value, expected,
"fetch_add atomicity violated: expected {expected}, got \
{final_value} (lost updates means the counter advanced \
less than NUM_WRITERS * ITERATIONS_PER_WRITER)",
);
}
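// Hammer every VirtioBlkCounters::record_* path from 8 writers while a
// reader asserts monotonicity; the final totals must match the iteration
// count exactly (no lost updates on any counter).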
#[test]
fn counters_concurrent_fetch_add_no_lost_updates() {
use std::sync::Barrier;
let counters = Arc::new(VirtioBlkCounters::default());
const NUM_WRITERS: u32 = 8;
const ITERATIONS_PER_WRITER: u32 = 5_000;
let barrier = Arc::new(Barrier::new(NUM_WRITERS as usize + 2));
let mut handles = Vec::with_capacity(NUM_WRITERS as usize);
for _ in 0..NUM_WRITERS {
let c_w = Arc::clone(&counters);
let barrier_w = Arc::clone(&barrier);
handles.push(thread::spawn(move || {
barrier_w.wait();
for _ in 0..ITERATIONS_PER_WRITER {
c_w.record_read(512);
c_w.record_write(1024);
c_w.record_flush();
c_w.record_throttled();
c_w.record_io_error();
}
}));
}
let c_reader = Arc::clone(&counters);
let barrier_r = Arc::clone(&barrier);
let reader = thread::spawn(move || {
barrier_r.wait();
let mut last_reads = 0u64;
for _ in 0..1_000 {
let now_reads = c_reader.reads_completed.load(Ordering::Relaxed);
assert!(
now_reads >= last_reads,
"reads_completed went backwards: {last_reads} -> {now_reads}",
);
last_reads = now_reads;
}
});
barrier.wait();
for h in handles {
h.join().expect("writer join");
}
reader.join().expect("reader join");
let total_iters = (NUM_WRITERS * ITERATIONS_PER_WRITER) as u64;
assert_eq!(
counters.reads_completed.load(Ordering::Relaxed),
total_iters,
"reads_completed lost an update",
);
assert_eq!(
counters.bytes_read.load(Ordering::Relaxed),
total_iters * 512,
"bytes_read lost an update",
);
assert_eq!(
counters.writes_completed.load(Ordering::Relaxed),
total_iters,
"writes_completed lost an update",
);
assert_eq!(
counters.bytes_written.load(Ordering::Relaxed),
total_iters * 1024,
"bytes_written lost an update",
);
assert_eq!(
counters.flushes_completed.load(Ordering::Relaxed),
total_iters,
"flushes_completed lost an update",
);
assert_eq!(
counters.throttled_count.load(Ordering::Relaxed),
total_iters,
"throttled_count lost an update",
);
assert_eq!(
counters.io_errors.load(Ordering::Relaxed),
total_iters,
"io_errors lost an update",
);
}
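// interrupt_status must live behind an Arc so the worker and the MMIO
// path can hold it concurrently.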
#[test]
fn interrupt_status_is_arc_shareable() {
let dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
let cloned = Arc::clone(&dev.interrupt_status);
assert!(
Arc::strong_count(&cloned) >= 2,
"interrupt_status must be Arc-shareable — strong_count \
after clone is {}",
Arc::strong_count(&cloned),
);
}
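// Throttle gauge, part 1: the first stall flips currently_throttled_gauge
// 0 -> 1, bumps the per-event throttled_count, and marks the worker
// state as stalled.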
#[test]
fn currently_throttled_gauge_increments_on_first_stall() {
let mem = make_chain_test_mem();
let mut dev = setup_iops1_drained_chain(&mem);
let c = dev.counters();
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
0,
"fresh device must have currently_throttled_gauge=0",
);
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
1,
"first stall must bump currently_throttled_gauge from 0 to 1",
);
assert_eq!(
c.throttled_count.load(Ordering::Relaxed),
1,
"first stall bumps throttled_count to 1",
);
assert!(
dev.worker.state().currently_stalled,
"BlkWorkerState::currently_stalled must be true after stall",
);
}
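// Throttle gauge, part 2: back-dating the ops bucket's last refill by 2s
// simulates a refill; the retry kick then completes the request, drops
// the gauge back to 0, and leaves throttled_count at 1.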
#[test]
fn currently_throttled_gauge_decrements_on_retry_success() {
let mem = make_chain_test_mem();
let mut dev = setup_iops1_drained_chain(&mem);
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
let c = dev.counters();
assert_eq!(c.currently_throttled_gauge.load(Ordering::Relaxed), 1);
dev.worker
.state_mut()
.ops_bucket
.set_last_refill_for_test(std::time::Instant::now() - std::time::Duration::from_secs(2));
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
0,
"retry success must decrement currently_throttled_gauge to 0",
);
assert!(
!dev.worker.state().currently_stalled,
"BlkWorkerState::currently_stalled must clear on retry success",
);
assert_eq!(
c.throttled_count.load(Ordering::Relaxed),
1,
"throttled_count is per-event; retry success doesn't bump it",
);
assert_eq!(c.reads_completed.load(Ordering::Relaxed), 1);
}
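// Throttle gauge, part 3: a second stall on the same head must not
// double-increment the gauge, even though throttled_count (per-event)
// reaches 2.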
#[test]
fn currently_throttled_gauge_no_double_inc_on_re_stall() {
let mem = make_chain_test_mem();
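// Pre-fill 0x6000 with a 0xEE sentinel; that address holds the status
// byte in the chain layout used elsewhere in this file, presumably so
// a stray status write by the stalled chain would be detectable.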
mem.write_slice(&[0xEEu8], GuestAddress(0x6000)).unwrap();
let mut dev = setup_iops1_drained_chain(&mem);
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
dev.worker
.state_mut()
.ops_bucket
.set_last_refill_for_test(std::time::Instant::now());
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
let c = dev.counters();
assert_eq!(
c.throttled_count.load(Ordering::Relaxed),
2,
"two stalls bump throttled_count twice (events)",
);
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
1,
"two stalls on same head must NOT double-increment the \
gauge — gauge represents one stuck request, not two \
stall events",
);
assert!(
dev.worker.state().currently_stalled,
"currently_stalled flag stays true across re-stall",
);
}
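// A status-0 reset while a request is stalled must release the gauge's
// pending increment and clear the stalled flag instead of leaking them.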
#[test]
fn reset_decrements_pending_throttle_gauge() {
let mem = make_chain_test_mem();
let mut dev = setup_iops1_drained_chain(&mem);
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
let c = dev.counters();
assert_eq!(c.currently_throttled_gauge.load(Ordering::Relaxed), 1);
write_reg(&mut dev, VIRTIO_MMIO_STATUS, 0);
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
0,
"reset must decrement currently_throttled_gauge so a \
reset-while-stalled does not leak a pending increment",
);
assert!(
!dev.worker.state().currently_stalled,
"reset must clear currently_stalled",
);
}
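// Complement of the test above: resetting a device that never stalled
// must leave the gauge untouched at zero.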
#[test]
fn reset_on_non_stalled_device_leaves_gauge_at_zero() {
let mut dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
let c = dev.counters();
assert_eq!(c.currently_throttled_gauge.load(Ordering::Relaxed), 0);
write_reg(&mut dev, VIRTIO_MMIO_STATUS, S_ACK);
write_reg(&mut dev, VIRTIO_MMIO_STATUS, 0);
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
0,
"reset on a non-stalled device must NOT touch the gauge",
);
assert!(
!dev.worker.state().currently_stalled,
"currently_stalled stays false on a non-stalled-device reset",
);
}
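// Baseline: a fresh device starts with the gauge at zero.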
#[test]
fn currently_throttled_gauge_initially_zero() {
let dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
let c = dev.counters();
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
0,
"currently_throttled_gauge must initialize to 0",
);
}
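// Same stall/retry cycle as the MMIO-kick tests above, but calling
// drain_bracket_impl directly on the inline engine so the DrainOutcome
// values themselves can be asserted.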
#[test]
fn currently_throttled_gauge_inline_redrain_succeeds_decrements_once() {
let mem = make_chain_test_mem();
let mut dev = setup_iops1_drained_chain(&mem);
let mem_ref = dev.mem.get().expect("mem set above");
let outcome1 = {
let WorkerEngine::Inline(engine) = &mut dev.worker.engine;
drain_bracket_impl(
&mut engine.state,
&mut dev.worker.queues,
mem_ref,
&dev.irq_evt,
&dev.interrupt_status,
&dev.device_status,
)
};
assert!(
matches!(
outcome1,
DrainOutcome::ThrottleStalled {
wait_nanos: 1_000_000_000
}
),
"first call must stall with wait_nanos=1_000_000_000 \
(capacity=1, rate=1, deficit=1 → 1s); got {:?}",
outcome1,
);
let c = dev.counters();
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
1,
"first stall must increment gauge to 1",
);
assert!(
dev.worker.state().currently_stalled,
"currently_stalled must be true after first stall",
);
assert_eq!(
c.throttled_count.load(Ordering::Relaxed),
1,
"first stall bumps throttled_count to 1",
);
dev.worker
.state_mut()
.ops_bucket
.set_last_refill_for_test(std::time::Instant::now() - std::time::Duration::from_secs(2));
let outcome2 = {
let WorkerEngine::Inline(engine) = &mut dev.worker.engine;
drain_bracket_impl(
&mut engine.state,
&mut dev.worker.queues,
mem_ref,
&dev.irq_evt,
&dev.interrupt_status,
&dev.device_status,
)
};
assert_eq!(
outcome2,
DrainOutcome::Done,
"second drain (post-refill) must complete; got {:?}",
outcome2,
);
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
0,
"inline re-drain success must dec gauge exactly once: \
1 → 0, not staying at 1, not going negative",
);
assert!(
!dev.worker.state().currently_stalled,
"currently_stalled must clear on retry success",
);
assert_eq!(
c.reads_completed.load(Ordering::Relaxed),
1,
"chain must complete on second drain",
);
assert_eq!(
c.throttled_count.load(Ordering::Relaxed),
1,
"second drain succeeded; throttled_count must NOT bump again",
);
}
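// Inline-engine re-stall: forcing the ops bucket to keep reporting a 1s
// deficit makes the second drain stall again; the gauge must hold at 1
// while the per-event throttled_count counts both stalls.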
#[test]
fn currently_throttled_gauge_inline_redrain_restalls_no_double_count() {
let mem = make_chain_test_mem();
let mut dev = setup_iops1_drained_chain(&mem);
dev.worker
.state_mut()
.ops_bucket
.set_forced_nanos_until_n_tokens_for_test(1_000_000_000);
let mem_ref = dev.mem.get().expect("mem set above");
let outcome1 = {
let WorkerEngine::Inline(engine) = &mut dev.worker.engine;
drain_bracket_impl(
&mut engine.state,
&mut dev.worker.queues,
mem_ref,
&dev.irq_evt,
&dev.interrupt_status,
&dev.device_status,
)
};
assert!(matches!(
outcome1,
DrainOutcome::ThrottleStalled {
wait_nanos: 1_000_000_000
}
));
let c = dev.counters();
assert_eq!(c.currently_throttled_gauge.load(Ordering::Relaxed), 1);
assert!(dev.worker.state().currently_stalled);
assert_eq!(c.throttled_count.load(Ordering::Relaxed), 1);
let outcome2 = {
let WorkerEngine::Inline(engine) = &mut dev.worker.engine;
drain_bracket_impl(
&mut engine.state,
&mut dev.worker.queues,
mem_ref,
&dev.irq_evt,
&dev.interrupt_status,
&dev.device_status,
)
};
assert!(
matches!(
outcome2,
DrainOutcome::ThrottleStalled {
wait_nanos: 1_000_000_000
}
),
"second drain (no refill) must also stall with \
wait_nanos=1_000_000_000; got {:?}",
outcome2,
);
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
1,
"re-stall on same head must NOT double-increment gauge \
(idempotent — gauge is per-request live state, not \
per-event)",
);
assert!(
dev.worker.state().currently_stalled,
"currently_stalled stays true across re-stall",
);
assert_eq!(
c.throttled_count.load(Ordering::Relaxed),
2,
"throttled_count IS per-event; two stall events must \
produce two bumps",
);
assert_eq!(
c.reads_completed.load(Ordering::Relaxed),
0,
"no chain completed; reads_completed must stay 0",
);
}
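// Hostile guest: an avail.idx advanced past QUEUE_MAX_SIZE must poison
// the queue (one counter bump, zero completions), and later kicks must
// short-circuit on the poison flag rather than livelock on the index.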
#[test]
fn inflated_avail_idx_poisons_queue_no_livelock() {
use std::num::Wrapping;
let cap = 4096u64;
let f = make_backed_file_with_pattern(cap, 0xAB);
let mut dev = VirtioBlk::new(f, cap, DiskThrottle::default());
let mem = make_chain_test_mem();
let queue_size: u16 = 16;
let mock = MockSplitQueue::create(&mem, GuestAddress(0), queue_size);
let header_addr = GuestAddress(0x4000);
let data_addr = GuestAddress(0x5000);
let status_addr = GuestAddress(0x6000);
write_blk_header(&mem, header_addr, VIRTIO_BLK_T_IN, 0);
let descs = [
RawDescriptor::from(SplitDescriptor::new(
header_addr.0,
VIRTIO_BLK_OUTHDR_SIZE as u32,
0,
0,
)),
RawDescriptor::from(SplitDescriptor::new(
data_addr.0,
512,
VRING_DESC_F_WRITE as u16,
0,
)),
RawDescriptor::from(SplitDescriptor::new(
status_addr.0,
1,
VRING_DESC_F_WRITE as u16,
0,
)),
];
mock.build_desc_chain(&descs).expect("build chain");
dev.set_mem(mem.clone());
wire_device_to_mock(&mut dev, &mock);
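// Advance avail.idx to QUEUE_MAX_SIZE + 1 (explicit wrapping arithmetic):
// the ring then claims more entries than it can hold, which the drain
// path must treat as InvalidAvailRingIndex.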
let bad_idx = Wrapping(0u16) + Wrapping(QUEUE_MAX_SIZE) + Wrapping(1u16);
mock.avail().idx().store(u16::to_le(bad_idx.0));
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
let c = dev.counters();
assert_eq!(
c.invalid_avail_idx_count.load(Ordering::Relaxed),
1,
"first hostile drain must bump invalid_avail_idx_count once",
);
assert!(
dev.worker.state().queue_poisoned,
"queue_poisoned must be set after InvalidAvailRingIndex",
);
assert_eq!(c.reads_completed.load(Ordering::Relaxed), 0);
assert_eq!(c.writes_completed.load(Ordering::Relaxed), 0);
assert_eq!(c.throttled_count.load(Ordering::Relaxed), 0);
assert_eq!(c.currently_throttled_gauge.load(Ordering::Relaxed), 0);
for _ in 0..5 {
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
}
assert_eq!(
c.invalid_avail_idx_count.load(Ordering::Relaxed),
1,
"poisoned queue must reject subsequent kicks without re-bumping \
the counter (per-event semantic + flag short-circuit)",
);
assert!(
dev.worker.state().queue_poisoned,
"poison flag stays set across re-kicks",
);
}
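// Poison is per-session: a status-0 reset clears the flag (the counter
// stays cumulative), and a legitimate chain on a freshly wired queue
// must then complete normally.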
#[test]
fn poisoned_queue_clears_on_reset() {
use std::num::Wrapping;
let cap = 4096u64;
let f = make_backed_file_with_pattern(cap, 0xAB);
let mut dev = VirtioBlk::new(f, cap, DiskThrottle::default());
let mem = make_chain_test_mem();
let queue_size: u16 = 16;
let mock = MockSplitQueue::create(&mem, GuestAddress(0), queue_size);
let header_addr = GuestAddress(0x4000);
let data_addr = GuestAddress(0x5000);
let status_addr = GuestAddress(0x6000);
write_blk_header(&mem, header_addr, VIRTIO_BLK_T_IN, 0);
let descs = [
RawDescriptor::from(SplitDescriptor::new(
header_addr.0,
VIRTIO_BLK_OUTHDR_SIZE as u32,
0,
0,
)),
RawDescriptor::from(SplitDescriptor::new(
data_addr.0,
512,
VRING_DESC_F_WRITE as u16,
0,
)),
RawDescriptor::from(SplitDescriptor::new(
status_addr.0,
1,
VRING_DESC_F_WRITE as u16,
0,
)),
];
mock.build_desc_chain(&descs).expect("build chain");
dev.set_mem(mem.clone());
wire_device_to_mock(&mut dev, &mock);
let bad_idx = Wrapping(0u16) + Wrapping(QUEUE_MAX_SIZE) + Wrapping(1u16);
mock.avail().idx().store(u16::to_le(bad_idx.0));
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
assert!(dev.worker.state().queue_poisoned);
let c = dev.counters();
assert_eq!(c.invalid_avail_idx_count.load(Ordering::Relaxed), 1);
write_reg(&mut dev, VIRTIO_MMIO_STATUS, 0);
assert!(
!dev.worker.state().queue_poisoned,
"reset must clear queue_poisoned",
);
assert_eq!(
c.invalid_avail_idx_count.load(Ordering::Relaxed),
1,
"invalid_avail_idx_count is cumulative across resets",
);
let mock2 = MockSplitQueue::create(&mem, GuestAddress(0), queue_size);
let header_addr2 = GuestAddress(0x7000);
let data_addr2 = GuestAddress(0x8000);
let status_addr2 = GuestAddress(0x9000);
write_blk_header(&mem, header_addr2, VIRTIO_BLK_T_IN, 0);
let descs2 = [
RawDescriptor::from(SplitDescriptor::new(
header_addr2.0,
VIRTIO_BLK_OUTHDR_SIZE as u32,
0,
0,
)),
RawDescriptor::from(SplitDescriptor::new(
data_addr2.0,
512,
VRING_DESC_F_WRITE as u16,
0,
)),
RawDescriptor::from(SplitDescriptor::new(
status_addr2.0,
1,
VRING_DESC_F_WRITE as u16,
0,
)),
];
mock2
.build_desc_chain(&descs2)
.expect("build chain after reset");
wire_device_to_mock(&mut dev, &mock2);
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
assert_eq!(c.reads_completed.load(Ordering::Relaxed), 1);
assert_eq!(
c.invalid_avail_idx_count.load(Ordering::Relaxed),
1,
"post-reset legitimate IO must NOT re-trip poison counter",
);
assert!(
!dev.worker.state().queue_poisoned,
"queue stays unpoisoned across legitimate post-reset IO",
);
}
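// While poisoned, a kick must bail out before touching the used ring:
// used.flags is snapshotted after the poisoning kick and must remain
// byte-identical across five further kicks.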
#[test]
fn poisoned_queue_kicks_dont_touch_used_flags() {
use std::num::Wrapping;
let cap = 4096u64;
let f = make_backed_file_with_pattern(cap, 0xAB);
let mut dev = VirtioBlk::new(f, cap, DiskThrottle::default());
let mem = make_chain_test_mem();
let queue_size: u16 = 16;
let mock = MockSplitQueue::create(&mem, GuestAddress(0), queue_size);
let header_addr = GuestAddress(0x4000);
let data_addr = GuestAddress(0x5000);
let status_addr = GuestAddress(0x6000);
write_blk_header(&mem, header_addr, VIRTIO_BLK_T_IN, 0);
let descs = [
RawDescriptor::from(SplitDescriptor::new(
header_addr.0,
VIRTIO_BLK_OUTHDR_SIZE as u32,
0,
0,
)),
RawDescriptor::from(SplitDescriptor::new(
data_addr.0,
512,
VRING_DESC_F_WRITE as u16,
0,
)),
RawDescriptor::from(SplitDescriptor::new(
status_addr.0,
1,
VRING_DESC_F_WRITE as u16,
0,
)),
];
mock.build_desc_chain(&descs).expect("build chain");
dev.set_mem(mem.clone());
wire_device_to_mock(&mut dev, &mock);
let bad_idx = Wrapping(0u16) + Wrapping(QUEUE_MAX_SIZE) + Wrapping(1u16);
mock.avail().idx().store(u16::to_le(bad_idx.0));
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
assert!(dev.worker.state().queue_poisoned);
let used_flags_after_poison: u16 = mem.read_obj(mock.used_addr()).expect("read used.flags");
for _ in 0..5 {
write_reg(&mut dev, VIRTIO_MMIO_QUEUE_NOTIFY, REQ_QUEUE as u32);
let f: u16 = mem
.read_obj(mock.used_addr())
.expect("read used.flags post-kick");
assert_eq!(
f, used_flags_after_poison,
"poisoned queue kicks must not modify used.flags \
(regression: gate moved below disable_notification)",
);
}
let c = dev.counters();
assert_eq!(
c.invalid_avail_idx_count.load(Ordering::Relaxed),
1,
"no additional poison events from re-kicks",
);
}
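// Bytes-bucket variant of the stall/retry cycle: only the bytes bucket
// is drained (the ops bucket keeps its tokens), showing the gauge and
// flags transition on a bytes-only stall exactly as on an iops stall.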
#[test]
fn currently_throttled_gauge_bytes_only_stall_and_retry() {
let mem = make_chain_test_mem();
let mut dev = setup_bytes_only_drained_chain(&mem, 16, 512);
assert!(
dev.worker.state_mut().ops_bucket.can_consume(1),
"iops bucket must NOT be drained — only bytes is the stall trigger",
);
assert!(
!dev.worker.state_mut().bytes_bucket.can_consume(512),
"bytes bucket must be drained so the chain stalls on bytes alone",
);
let mem_ref = dev.mem.get().expect("mem set above");
let outcome1 = {
let WorkerEngine::Inline(engine) = &mut dev.worker.engine;
drain_bracket_impl(
&mut engine.state,
&mut dev.worker.queues,
mem_ref,
&dev.irq_evt,
&dev.interrupt_status,
&dev.device_status,
)
};
assert!(
matches!(
outcome1,
DrainOutcome::ThrottleStalled {
wait_nanos: 1_000_000_000
}
),
"first call must stall on bytes bucket with \
wait_nanos=1_000_000_000 (capacity=512, rate=512, \
deficit=512); got {:?}",
outcome1,
);
let c = dev.counters();
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
1,
"bytes-only stall must inc gauge 0→1 — gauge transitions \
on stall regardless of which bucket triggered it",
);
assert!(
dev.worker.state().currently_stalled,
"currently_stalled must be true after first stall",
);
assert_eq!(
c.throttled_count.load(Ordering::Relaxed),
1,
"first stall bumps throttled_count to 1",
);
assert_eq!(
c.reads_completed.load(Ordering::Relaxed),
0,
"stalled chain must not have completed",
);
dev.worker
.state_mut()
.bytes_bucket
.set_last_refill_for_test(Instant::now() - Duration::from_secs(2));
let outcome2 = {
let WorkerEngine::Inline(engine) = &mut dev.worker.engine;
drain_bracket_impl(
&mut engine.state,
&mut dev.worker.queues,
mem_ref,
&dev.irq_evt,
&dev.interrupt_status,
&dev.device_status,
)
};
assert_eq!(
outcome2,
DrainOutcome::Done,
"second drain (post bytes-bucket refill) must complete; \
got {:?}",
outcome2,
);
assert_eq!(
c.currently_throttled_gauge.load(Ordering::Relaxed),
0,
"bytes-only retry success must dec gauge exactly once: \
1 → 0, not staying at 1, not going negative",
);
assert!(
!dev.worker.state().currently_stalled,
"currently_stalled must clear on retry success",
);
assert_eq!(
c.reads_completed.load(Ordering::Relaxed),
1,
"chain must complete on second drain",
);
assert_eq!(
c.throttled_count.load(Ordering::Relaxed),
1,
"second drain succeeded; throttled_count must NOT bump again",
);
}
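// OnceLock bind semantics: a second set_mem is warn-and-skip, so both
// log lines must fire and the first binding must survive, verified by
// pointer identity.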
#[tracing_test::traced_test]
#[test]
fn set_mem_twice_emits_warn() {
let mut dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
let mem_a = make_guest_mem(4096);
let mem_b = make_guest_mem(8192);
dev.set_mem(mem_a);
let first_ptr = dev.mem.get().expect("set_mem populated OnceLock") as *const GuestMemoryMmap;
assert!(
!logs_contain("set_mem called on already-initialised"),
"first set_mem must not emit the already-initialised warn",
);
dev.set_mem(mem_b);
assert!(
logs_contain("set_mem called on already-initialised"),
"second set_mem must emit the already-initialised warn so \
a duplicate-bind regression is operator-visible",
);
assert!(
logs_contain("guest memory binding unchanged"),
"warn must explain the no-op semantic — \
'guest memory binding unchanged' tells the operator \
the duplicate call did NOT replace the binding",
);
let after_ptr = dev
.mem
.get()
.expect("OnceLock still populated after second set_mem")
as *const GuestMemoryMmap;
assert_eq!(
first_ptr, after_ptr,
"OnceLock must retain the first GuestMemoryMmap; the \
warn-and-skip path must NOT overwrite the binding on \
the second call",
);
}
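// Feature-negotiation subset rule: acking an unadvertised bit
// (F_DISCARD) must leave FEATURES_OK unset and log the dedicated warn;
// re-negotiating with a legal subset must then succeed.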
#[tracing_test::traced_test]
#[test]
fn features_ok_rejected_with_unadvertised_bit() {
let mut dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
write_reg(&mut dev, VIRTIO_MMIO_STATUS, S_ACK);
write_reg(&mut dev, VIRTIO_MMIO_STATUS, S_DRV);
const VIRTIO_BLK_F_DISCARD: u32 = 13;
assert_eq!(
dev.device_features() & (1u64 << VIRTIO_BLK_F_DISCARD),
0,
"precondition: device must NOT advertise F_DISCARD \
(this test depends on it being unadvertised)",
);
write_reg(&mut dev, VIRTIO_MMIO_DRIVER_FEATURES_SEL, 1);
write_reg(
&mut dev,
VIRTIO_MMIO_DRIVER_FEATURES,
1 << (VIRTIO_F_VERSION_1 - 32),
);
write_reg(&mut dev, VIRTIO_MMIO_DRIVER_FEATURES_SEL, 0);
write_reg(
&mut dev,
VIRTIO_MMIO_DRIVER_FEATURES,
1u32 << VIRTIO_BLK_F_DISCARD,
);
write_reg(&mut dev, VIRTIO_MMIO_STATUS, S_FEAT);
assert_eq!(
dev.device_status.load(Ordering::Acquire),
S_DRV,
"FEATURES_OK must be rejected when driver acked an \
unadvertised feature bit (subset rule violation)",
);
let status = read_reg(&dev, VIRTIO_MMIO_STATUS);
assert_eq!(
status, S_DRV,
"MMIO STATUS read-back must show FEATURES_OK is unset \
after subset-rule rejection",
);
assert!(
logs_contain("unadvertised feature bits"),
"warn must cite 'unadvertised feature bits' so the \
operator can distinguish this rejection branch from \
the version-1 rejection branch",
);
write_reg(&mut dev, VIRTIO_MMIO_DRIVER_FEATURES_SEL, 0);
write_reg(&mut dev, VIRTIO_MMIO_DRIVER_FEATURES, 0);
write_reg(&mut dev, VIRTIO_MMIO_STATUS, S_FEAT);
assert_eq!(
dev.device_status.load(Ordering::Acquire),
S_FEAT,
"FEATURES_OK must be accepted once driver_features is \
a subset of device_features (only VERSION_1 set)",
);
}
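// Pause/resume plumbing: the paused=true construction sentinel, the
// Release/Acquire handoff on the flag, resume() against the inline
// engine, and pause_evt accumulating one count per pause() call.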
#[test]
fn pause_writes_evt_and_resume_clears_paused() {
let dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
assert!(
dev.is_paused(),
"freshly constructed device must start with the V1 \
paused=true sentinel — initial spawn is deferred to \
DRIVER_OK and the rendezvous must pass vacuously \
until the worker actually starts"
);
dev.paused.store(false, Ordering::Release);
assert!(
!dev.is_paused(),
"after the worker clears the sentinel, is_paused() \
must observe the Release-store of false"
);
dev.paused.store(true, Ordering::Release);
assert!(
dev.is_paused(),
"is_paused must observe the worker's Release store"
);
let unparked = dev.resume();
assert!(
!unparked,
"cfg(test) inline engine has no worker thread; resume() returns false"
);
assert!(
!dev.is_paused(),
"resume() must clear the paused flag (Release store)"
);
dev.pause();
let count = dev
.pause_evt
.read()
.expect("pause_evt should be readable after pause()");
assert_eq!(
count, 1,
"pause() must write exactly 1 to pause_evt for a single pause request"
);
dev.pause();
dev.pause();
dev.pause();
let count3 = dev
.pause_evt
.read()
.expect("pause_evt readable after 3 pauses");
assert_eq!(
count3, 3,
"three coalesced pause() calls must accumulate to 3 in counter mode"
);
}
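// NEEDS_RESET is sticky: set_status must neither clear it nor commit
// the driver's requested status advance while it is set.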
#[test]
fn set_status_preserves_needs_reset_when_already_set() {
use virtio_bindings::virtio_config::{
VIRTIO_CONFIG_S_ACKNOWLEDGE, VIRTIO_CONFIG_S_NEEDS_RESET,
};
let mut dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
dev.device_status
.fetch_or(VIRTIO_CONFIG_S_NEEDS_RESET, Ordering::SeqCst);
assert_eq!(
dev.device_status.load(Ordering::Acquire),
VIRTIO_CONFIG_S_NEEDS_RESET,
"pre-condition: NEEDS_RESET planted, no FSM bits set"
);
dev.set_status(VIRTIO_CONFIG_S_ACKNOWLEDGE);
let observed = dev.device_status.load(Ordering::Acquire);
assert_ne!(
observed & VIRTIO_CONFIG_S_NEEDS_RESET,
0,
"set_status must NOT clobber NEEDS_RESET via the \
monotone-bit gate path; got device_status={:#x}",
observed,
);
assert_eq!(
observed & VIRTIO_CONFIG_S_ACKNOWLEDGE,
0,
"ACK must NOT be committed when NEEDS_RESET is set — the \
monotone-bit gate rejects the advance; got device_status={:#x}",
observed,
);
}
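// Racy companion to the test above: a background thread keeps ORing in
// NEEDS_RESET while the main thread cycles store(0) / set_status(ACK).
// The loop is a smoke run with no asserts under the race; the
// deterministic check afterwards replays the single-threaded invariant.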
#[test]
fn set_status_cas_preserves_concurrent_needs_reset() {
use std::sync::atomic::AtomicBool;
use virtio_bindings::virtio_config::{
VIRTIO_CONFIG_S_ACKNOWLEDGE, VIRTIO_CONFIG_S_NEEDS_RESET,
};
let mut dev = make_device(VIRTIO_BLK_DEFAULT_CAPACITY_BYTES, DiskThrottle::default());
let device_status_handle = Arc::clone(&dev.device_status);
const ITERS: u32 = 1024;
let stop = Arc::new(AtomicBool::new(false));
let stop_thread = Arc::clone(&stop);
let poison_thread = std::thread::Builder::new()
.name("ktstr-vblk-cas-poison".to_string())
.spawn(move || {
while !stop_thread.load(Ordering::Acquire) {
device_status_handle.fetch_or(VIRTIO_CONFIG_S_NEEDS_RESET, Ordering::SeqCst);
std::thread::yield_now();
}
})
.expect("spawn cas-poison thread");
for _ in 0..ITERS {
dev.device_status.store(0, Ordering::Release);
std::thread::yield_now();
dev.set_status(VIRTIO_CONFIG_S_ACKNOWLEDGE);
}
dev.device_status
.fetch_or(VIRTIO_CONFIG_S_NEEDS_RESET, Ordering::SeqCst);
dev.set_status(VIRTIO_CONFIG_S_ACKNOWLEDGE);
let final_status = dev.device_status.load(Ordering::Acquire);
assert_ne!(
final_status & VIRTIO_CONFIG_S_NEEDS_RESET,
0,
"final deterministic check: set_status must NOT clobber \
NEEDS_RESET; got device_status={:#x}",
final_status,
);
stop.store(true, Ordering::Release);
poison_thread
.join()
.expect("poison thread should not panic");
}