#![allow(dead_code)]
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use crate::fuse::{
build_inval_entry, build_inval_inode, FsBackend, FuseServer, InHeader, MemoryFs, Notifier,
};
use super::queue::Queue;
use super::{ShmRegion, VirtioDevice, VIRTIO_ID_FS};
// NOTE(review): bit 32 is presumably VIRTIO_F_VERSION_1 — confirm against the transport code.
/// Feature bits returned by `features()`: only bit 32 is set.
const FEATURES: u64 = 1u64 << 32;
/// Size in bytes of the tag field that opens the device config space (see `config()`).
const TAG_LEN: usize = 36;
/// Default DAX window size: 8 GiB.
pub const DEFAULT_DAX_WINDOW_BYTES: u64 = 8 * 1024 * 1024 * 1024;
/// Static configuration of one virtio-fs device, validated by `VirtioFs::with_backend`.
#[derive(Clone, Debug)]
pub struct VirtioFsConfig {
/// Tag copied into the config space; must be shorter than `TAG_LEN` bytes.
pub tag: String,
/// Number of request queues advertised in the config space; must be at least 1.
pub num_request_queues: u32,
/// Guest-physical address of the DAX window; must be 16 KiB aligned.
pub dax_window_gpa: u64,
/// Length of the DAX window in bytes; must be 16 KiB aligned. Zero means no
/// shared-memory region is reported by `shm_regions()`.
pub dax_window_len: u64,
}
/// virtio-fs device: shared state plus the handle of its I/O worker thread.
pub struct VirtioFs {
/// State shared with the worker thread.
core: Arc<FsCore>,
/// Join handle of the worker thread; taken (left as `None`) by `shutdown_io`.
worker: Mutex<Option<std::thread::JoinHandle<()>>>,
}
/// State shared between the device front end and the I/O worker thread.
struct FsCore {
/// Configuration the device was created with.
cfg: VirtioFsConfig,
/// Virtqueues installed by `activate`; index 0 is drained by `drain_hiprio_queue`,
/// index 1 by `process_one_request`.
queues: Mutex<Vec<Queue>>,
/// Set once `activate` has installed the queues.
activated: AtomicBool,
/// Optional callback invoked by `raise_irq`.
irq_raise: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
/// FUSE server that request payloads are dispatched to.
fuse: Mutex<FuseServer>,
/// Writable-only chains popped from queue 0 and kept for later notifications.
notif_pool: Mutex<Vec<HeldChain>>,
/// Worker/pause control state, guarded together with `io_cv`.
io: Mutex<IoCtl>,
/// Condition variable signalled whenever `io` changes in a way waiters care about.
io_cv: Condvar,
}
/// Control flags and counters protected by `FsCore::io`.
#[derive(Default)]
struct IoCtl {
/// Set by `kick_worker`; cleared by the worker just before it starts a drain.
kicked: bool,
/// True while the worker is inside `drain_request_queue`.
draining: bool,
/// Set by `shutdown_io`; makes the worker exit and `try_enter_io` refuse.
stop: bool,
/// Number of outstanding `pause_io` calls not yet matched by `resume_io`.
pause_holds: u32,
/// Number of I/O sections currently between `try_enter_io` and `exit_io`.
active_io: u32,
}
/// A descriptor chain popped from queue 0 and parked in `notif_pool`.
struct HeldChain {
/// Head value returned by `pop_chain`, later passed to `add_used`.
head: u16,
/// Descriptors making up the chain.
chain: Vec<super::queue::Desc>,
}
/// Scope guard that keeps I/O paused on a set of virtio-fs devices while it is alive.
///
/// Creating the guard calls [`VirtioFs::pause_io`] on every device in slice order;
/// dropping it calls [`VirtioFs::resume_io`] on every device in the same order.
pub struct FsIoPauseGuard<'a>(&'a [Arc<VirtioFs>]);

impl<'a> FsIoPauseGuard<'a> {
    /// Pauses every device in `devices` and returns the guard that will resume them.
    pub fn new(devices: &'a [Arc<VirtioFs>]) -> Self {
        devices.iter().for_each(|dev| dev.pause_io());
        Self(devices)
    }
}

impl Drop for FsIoPauseGuard<'_> {
    /// Releases one pause hold on each guarded device.
    fn drop(&mut self) {
        self.0.iter().for_each(|dev| dev.resume_io());
    }
}
impl VirtioFs {
pub fn new(cfg: VirtioFsConfig) -> Self {
Self::with_backend(cfg, Arc::new(MemoryFs::new()))
}
/// Creates a device that serves FUSE requests from `backend` and spawns its I/O
/// worker thread (named `sm-virtiofs-io`).
///
/// # Panics
/// Panics if the DAX window GPA or length is not 16 KiB aligned, if the tag is
/// `TAG_LEN` bytes or longer, if `num_request_queues` is zero, or if the worker
/// thread cannot be spawned.
pub fn with_backend(cfg: VirtioFsConfig, backend: Arc<dyn FsBackend>) -> Self {
    let window_gpa = cfg.dax_window_gpa;
    assert!(
        window_gpa & 0x3FFF == 0,
        "DAX window GPA must be 16 KiB aligned (got {:#x})",
        window_gpa
    );
    let window_len = cfg.dax_window_len;
    assert!(
        window_len & 0x3FFF == 0,
        "DAX window len must be 16 KiB aligned (got {:#x})",
        window_len
    );
    let tag_len = cfg.tag.len();
    assert!(
        tag_len < TAG_LEN,
        "tag too long (max {} bytes, got {})",
        TAG_LEN - 1,
        tag_len
    );
    assert!(
        cfg.num_request_queues >= 1,
        "must have at least 1 request queue"
    );

    let core = Arc::new(FsCore {
        cfg,
        queues: Mutex::new(Vec::new()),
        activated: AtomicBool::new(false),
        irq_raise: Mutex::new(None),
        fuse: Mutex::new(FuseServer::new(backend)),
        notif_pool: Mutex::new(Vec::new()),
        io: Mutex::new(IoCtl::default()),
        io_cv: Condvar::new(),
    });

    // The worker owns its own reference so the shared state outlives the thread.
    let worker = {
        let shared = Arc::clone(&core);
        std::thread::Builder::new()
            .name("sm-virtiofs-io".into())
            .spawn(move || FsCore::worker_loop(shared))
            .expect("spawn virtio-fs io worker thread")
    };

    Self {
        core,
        worker: Mutex::new(Some(worker)),
    }
}
/// Installs (or replaces) the callback that `raise_irq` invokes.
pub fn set_irq_raise(&self, f: Arc<dyn Fn() + Send + Sync>) {
    let mut slot = self.core.irq_raise.lock().unwrap();
    *slot = Some(f);
}
/// Discards every chain currently parked in the notification pool.
pub fn reset_for_restore(&self) {
    let mut pool = self.core.notif_pool.lock().unwrap();
    pool.clear();
}
/// Returns the mutex guarding this device's [`FuseServer`].
pub fn fuse_server(&self) -> &Mutex<FuseServer> {
&self.core.fuse
}
/// Takes one pause hold and blocks until no I/O section is active.
///
/// While at least one hold is outstanding, `try_enter_io` refuses new sections
/// and the worker does not start a drain. Each call must be matched by one
/// [`VirtioFs::resume_io`].
pub fn pause_io(&self) {
    let mut state = self.core.io.lock().unwrap();
    state.pause_holds += 1;
    // Sections that were already running are allowed to finish; `exit_io` wakes us.
    let _idle = self
        .core
        .io_cv
        .wait_while(state, |s| s.active_io > 0)
        .unwrap();
}
/// Releases one pause hold; when the last hold is gone, wakes all `io_cv`
/// waiters so a deferred kick can be serviced.
///
/// Calling this without an outstanding hold leaves the counter at zero.
pub fn resume_io(&self) {
    let mut state = self.core.io.lock().unwrap();
    state.pause_holds = state.pause_holds.saturating_sub(1);
    let fully_resumed = state.pause_holds == 0;
    if fully_resumed {
        self.core.io_cv.notify_all();
    }
}
/// Blocks until the device has no pending or in-flight I/O.
///
/// "Idle" means no I/O section is active, the worker is not draining, and no
/// kick is waiting to be serviced. After `shutdown_io` has set `stop`, a
/// pending kick is ignored: the worker has exited and will never clear it, so
/// waiting on it would block forever.
///
/// While I/O is paused, a pending kick keeps this call blocked until
/// `resume_io` lets the worker service it.
pub fn wait_io_idle(&self) {
    let mut io = self.core.io.lock().unwrap();
    while io.active_io > 0 || io.draining || (io.kicked && !io.stop) {
        io = self.core.io_cv.wait(io).unwrap();
    }
}
/// Stops the I/O worker and waits for its thread to finish.
///
/// Sets `stop` (after which `try_enter_io` refuses every new section), wakes all
/// `io_cv` waiters, and joins the worker thread. The join handle is taken out of
/// `self.worker`, so later calls skip the join.
pub fn shutdown_io(&self) {
{
let mut io = self.core.io.lock().unwrap();
io.stop = true;
self.core.io_cv.notify_all();
}
// The `io` lock is released above so the worker can observe `stop` and exit.
if let Some(h) = self.worker.lock().unwrap().take() {
// The join result (a worker panic, if any) is deliberately ignored.
let _ = h.join();
}
}
}
/// Dropping the device stops and joins its I/O worker thread.
impl Drop for VirtioFs {
fn drop(&mut self) {
self.shutdown_io();
}
}
impl FsCore {
/// Body of the I/O worker thread spawned by `VirtioFs::with_backend`.
///
/// Runs until `stop` is set. Whenever `kicked` is set and no pause hold is
/// outstanding, it clears `kicked`, marks `draining`, drains the request queue
/// with the `io` lock released, then clears `draining` and wakes all `io_cv`
/// waiters. Otherwise it sleeps on `io_cv`.
fn worker_loop(core: Arc<FsCore>) {
let mut io = core.io.lock().unwrap();
loop {
if io.stop {
return;
}
if io.kicked && io.pause_holds == 0 {
io.kicked = false;
io.draining = true;
// Release the control lock for the drain; `drain_request_queue` re-takes it
// per request through `try_enter_io` / `exit_io`.
drop(io);
core.drain_request_queue();
io = core.io.lock().unwrap();
io.draining = false;
// Wake `wait_io_idle` callers now that this drain has finished.
core.io_cv.notify_all();
// Re-check `stop` and `kicked` before sleeping: a kick may have arrived meanwhile.
continue;
}
io = core.io_cv.wait(io).unwrap();
}
}
/// Records that request-queue work is pending and wakes every `io_cv` waiter.
fn kick_worker(&self) {
let mut io = self.io.lock().unwrap();
io.kicked = true;
self.io_cv.notify_all();
}
/// Tries to open an I/O section.
///
/// Returns `false` without side effects when the device is stopped or paused.
/// Otherwise increments `active_io` and returns `true`; the caller must then
/// call `exit_io` exactly once.
fn try_enter_io(&self) -> bool {
    let mut state = self.io.lock().unwrap();
    let gated = state.stop || state.pause_holds > 0;
    if !gated {
        state.active_io += 1;
    }
    !gated
}
/// Closes an I/O section opened by a successful `try_enter_io` and wakes every
/// `io_cv` waiter (e.g. `pause_io`, `wait_io_idle`).
///
/// Must be paired with `try_enter_io`; an unmatched call underflows `active_io`.
fn exit_io(&self) {
let mut io = self.io.lock().unwrap();
io.active_io -= 1;
self.io_cv.notify_all();
}
/// Invokes the interrupt callback installed by `set_irq_raise`, if any.
///
/// The callback is cloned out of the mutex and the lock is released *before*
/// it runs. Binding the clone to a local matters: a guard created inside an
/// `if let` scrutinee stays alive for the whole `if let` body, which would keep
/// `irq_raise` locked during the callback and deadlock any callback that
/// touches it (for example through `set_irq_raise`).
fn raise_irq(&self) {
    let callback = self.irq_raise.lock().unwrap().clone();
    if let Some(raise) = callback {
        raise();
    }
}
/// Pops every available chain from queue 0.
///
/// A chain with at least one device-readable descriptor is completed
/// immediately with a used length of 0. A chain made only of device-writable
/// descriptors is parked in `notif_pool` for `push_notification_in_section`.
/// One interrupt is raised at the end if any chain was completed.
///
/// Does nothing before activation, or if queue 0 is missing or not ready.
// NOTE(review): this path calls `add_used` without `try_enter_io`, so it is not
// gated by `pause_io` / `shutdown_io` — confirm that is intentional.
fn drain_hiprio_queue(&self) {
if !self.activated.load(Ordering::Acquire) {
return;
}
let mut qs = self.queues.lock().unwrap();
let Some(hi) = qs.get_mut(0) else { return };
if !hi.ready {
return;
}
let mut acked_request = false;
while let Some((head, chain)) = hi.pop_chain() {
let has_readable = chain
.iter()
.any(|d| d.flags & super::queue::VRING_DESC_F_WRITE == 0);
if has_readable {
hi.add_used(head, 0);
acked_request = true;
} else {
// Lock order is `queues` then `notif_pool`, matching `push_notification_in_section`.
self.notif_pool
.lock()
.unwrap()
.push(HeldChain { head, chain });
log_notify(|| {
format!(
"hipri stashed: pool size now {}",
self.notif_pool.lock().unwrap().len()
)
});
}
}
// Release the queue lock before calling out to the interrupt callback.
drop(qs);
if acked_request {
self.raise_irq();
}
}
/// Delivers `bytes` to the guest as a notification, inside an I/O section.
///
/// Returns `false` (and delivers nothing) when the device is not activated,
/// when I/O is paused or stopped, or when `push_notification_in_section`
/// reports failure.
fn push_notification(&self, bytes: &[u8]) -> bool {
    // Short-circuit order matters: the I/O section is only entered once activated.
    if !self.activated.load(Ordering::Acquire) || !self.try_enter_io() {
        return false;
    }
    let delivered = self.push_notification_in_section(bytes);
    self.exit_io();
    delivered
}
/// Writes `bytes` into a parked chain from `notif_pool`, completes it on queue 0
/// and raises an interrupt.
///
/// Returns `false` when queue 0 does not exist or the pool is empty. The queue
/// is looked up *before* a chain is taken from the pool, so a missing queue no
/// longer consumes (and loses) a parked chain.
///
/// If the chain's writable descriptors are smaller than `bytes`, only the part
/// that fits is written; the used length reports the bytes actually written.
///
/// Must be called between `try_enter_io` and `exit_io`.
fn push_notification_in_section(&self, bytes: &[u8]) -> bool {
    // Lock order: `queues` first, then `notif_pool` (same as `drain_hiprio_queue`).
    let mut qs = self.queues.lock().unwrap();
    let Some(q) = qs.get_mut(0) else { return false };
    let Some(held) = self.notif_pool.lock().unwrap().pop() else {
        return false;
    };

    let mut written = 0usize;
    for d in held
        .chain
        .iter()
        .filter(|d| d.flags & super::queue::VRING_DESC_F_WRITE != 0)
    {
        if written >= bytes.len() {
            break;
        }
        let take = (bytes.len() - written).min(d.len as usize);
        q.mem.write_slice(d.addr, &bytes[written..written + take]);
        written += take;
    }
    q.add_used(held.head, written as u32);

    // Release the queue lock before calling out to the interrupt callback.
    drop(qs);
    self.raise_irq();
    true
}
/// Services requests one at a time until none is pending or I/O becomes gated.
///
/// Each request runs in its own `try_enter_io` / `exit_io` section, so a
/// `pause_io` caller waits for at most the request in flight. If entering a
/// section fails (paused or stopped), the kick is re-armed so the remaining
/// work is picked up once I/O resumes.
fn drain_request_queue(&self) {
if !self.activated.load(Ordering::Acquire) {
return;
}
loop {
if !self.try_enter_io() {
// Re-arm `kicked` so the worker retries after `resume_io`.
self.kick_worker();
return;
}
let processed = self.process_one_request();
self.exit_io();
if !processed {
return;
}
}
}
/// Pops and services at most one FUSE request from the request queues.
///
/// Queue index 0 is the hiprio queue; every later index is treated as a request
/// queue. The queues are scanned in index order and the first ready queue with
/// a pending chain is serviced, so requests on any advertised request queue are
/// handled (not only those on index 1).
///
/// Returns `true` if a chain was consumed (successfully or rejected) and
/// `false` if no ready request queue had a pending chain.
///
/// A chain whose readable part exceeds `MAX_REQUEST_BYTES`, or is shorter than
/// an `InHeader`, is completed with a used length of 0 and no reply.
///
/// Must be called between `try_enter_io` and `exit_io`.
fn process_one_request(&self) -> bool {
    const MAX_REQUEST_BYTES: usize = 8 * 1024 * 1024;

    let mut qs = self.queues.lock().unwrap();
    let Some((q, head, chain)) = qs
        .iter_mut()
        .skip(1)
        .filter(|q| q.ready)
        .find_map(|q| q.pop_chain().map(|(head, chain)| (q, head, chain)))
    else {
        return false;
    };

    let hdr_size = core::mem::size_of::<InHeader>();
    let written = match super::queue::read_readable_capped(&chain, &q.mem, MAX_REQUEST_BYTES) {
        None => {
            eprintln!(
                "[virtio-fs] request chain exceeds {MAX_REQUEST_BYTES} byte cap; rejecting"
            );
            0
        }
        Some(in_bytes) if in_bytes.len() < hdr_size => {
            eprintln!(
                "[virtio-fs] short request: {} bytes < InHeader ({hdr_size})",
                in_bytes.len()
            );
            0
        }
        Some(in_bytes) => {
            // SAFETY: `in_bytes` holds at least `size_of::<InHeader>()` bytes (checked by
            // the arm above) and `read_unaligned` has no alignment requirement.
            // NOTE(review): assumes `InHeader` is a plain integer-only struct for which
            // every bit pattern is valid — confirm against its definition.
            let hdr: InHeader =
                unsafe { core::ptr::read_unaligned(in_bytes.as_ptr() as *const InHeader) };
            let payload = &in_bytes[hdr_size..];
            let reply = self.fuse.lock().unwrap().dispatch(&hdr, payload);
            super::queue::write_writable(&chain, &q.mem, &reply.bytes) as u32
        }
    };

    q.add_used(head, written);
    // Release the queue lock before calling out to the interrupt callback.
    drop(qs);
    self.raise_irq();
    true
}
}
/// Host-to-guest cache invalidation, delivered through parked hiprio chains.
///
/// Delivery is best effort: the outcome is only reported through `log_notify`.
impl Notifier for VirtioFs {
    /// Builds an inode-invalidation message for `nodeid` / `off` / `len` and tries to push it.
    fn invalidate_inode(&self, nodeid: u64, off: i64, len: i64) {
        let core = &self.core;
        let ok = core.push_notification(&build_inval_inode(nodeid, off, len));
        // The pool is only inspected if tracing is enabled (the closure is lazy).
        log_notify(|| {
            format!(
                "INVAL_INODE nodeid={nodeid} off={off} len={len} pushed={ok} pool_left={}",
                core.notif_pool.lock().unwrap().len()
            )
        });
    }

    /// Builds an entry-invalidation message for `name` under `parent_nodeid` and tries to push it.
    fn invalidate_entry(&self, parent_nodeid: u64, name: &[u8]) {
        let core = &self.core;
        let ok = core.push_notification(&build_inval_entry(parent_nodeid, name));
        log_notify(|| {
            format!(
                "INVAL_ENTRY parent={parent_nodeid} name={:?} pushed={ok} pool_left={}",
                std::str::from_utf8(name).unwrap_or("<bytes>"),
                core.notif_pool.lock().unwrap().len()
            )
        });
    }
}
/// Emits a trace line if `crate::trace::fuse_target()` returns a target.
///
/// `make_msg` is only called when a target exists. A target of `"1"` or
/// `"stderr"` prints to stderr; any other value is used as a file path that is
/// opened in create+append mode. Failures to open or write are ignored.
fn log_notify<F: FnOnce() -> String>(make_msg: F) {
    use std::io::Write;

    let Some(target) = crate::trace::fuse_target() else {
        return;
    };
    let s = make_msg();
    let target = target.to_string_lossy().into_owned();
    match target.as_str() {
        "1" | "stderr" => eprintln!("[virtio-fs] {s}"),
        path => {
            let opened = std::fs::OpenOptions::new()
                .create(true)
                .append(true)
                .open(path);
            if let Ok(mut f) = opened {
                let _ = writeln!(f, "[virtio-fs] {s}");
            }
        }
    }
}
/// virtio transport interface of the virtio-fs device.
impl VirtioDevice for VirtioFs {
    /// Returns `VIRTIO_ID_FS`.
    fn device_id(&self) -> u32 {
        VIRTIO_ID_FS
    }

    /// One queue at index 0 plus the configured number of request queues.
    fn num_queues(&self) -> usize {
        1 + self.core.cfg.num_request_queues as usize
    }

    /// Feature bits offered to the driver.
    fn features(&self) -> u64 {
        FEATURES
    }

    /// Builds the config space: a `TAG_LEN`-byte zero-padded tag followed by
    /// `num_request_queues` as a little-endian `u32`.
    fn config(&self) -> Vec<u8> {
        let mut buf = vec![0u8; TAG_LEN + 4];
        let tag_bytes = self.core.cfg.tag.as_bytes();
        // `with_backend` already rejects over-long tags; the clamp keeps this panic-free.
        let take = tag_bytes.len().min(TAG_LEN - 1);
        buf[..take].copy_from_slice(&tag_bytes[..take]);
        buf[TAG_LEN..TAG_LEN + 4].copy_from_slice(&self.core.cfg.num_request_queues.to_le_bytes());
        buf
    }

    /// Doorbell handler: queue 0 is drained on the calling thread, any other
    /// queue index wakes the I/O worker.
    fn notify(&self, q: u16) {
        match q {
            0 => self.core.drain_hiprio_queue(),
            _ => self.core.kick_worker(),
        }
    }

    /// Installs `queues`, marks the device activated, logs the configuration and
    /// kicks the worker so requests already pending in the queues are serviced.
    fn activate(&self, queues: Vec<Queue>) {
        *self.core.queues.lock().unwrap() = queues;
        self.core.activated.store(true, Ordering::Release);
        let cfg = &self.core.cfg;
        eprintln!(
            "[virtio-fs] activated: tag={:?} req_queues={} dax_window={:#x}..{:#x} ({} MiB)",
            cfg.tag,
            cfg.num_request_queues,
            cfg.dax_window_gpa,
            // Saturating: a window ending at the top of the address space must not
            // overflow (and panic in overflow-checked builds) just to print a log line.
            cfg.dax_window_gpa.saturating_add(cfg.dax_window_len),
            cfg.dax_window_len / (1024 * 1024),
        );
        self.core.kick_worker();
    }

    /// Returns a clone of the currently installed queues.
    fn snapshot_queues(&self) -> Vec<Queue> {
        self.core.queues.lock().unwrap().clone()
    }

    /// Reports the DAX window as shared-memory region 0, or nothing when its
    /// length is zero.
    fn shm_regions(&self) -> Vec<ShmRegion> {
        if self.core.cfg.dax_window_len == 0 {
            return Vec::new();
        }
        vec![ShmRegion {
            id: 0,
            gpa: self.core.cfg.dax_window_gpa,
            len: self.core.cfg.dax_window_len,
        }]
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// The config space is a 36-byte zero-padded tag followed by the request-queue
/// count as a little-endian `u32`.
#[test]
fn config_layout() {
    let dev = VirtioFs::new(VirtioFsConfig {
        tag: "shared".into(),
        num_request_queues: 1,
        dax_window_gpa: 0x80_0000_0000,
        dax_window_len: DEFAULT_DAX_WINDOW_BYTES,
    });
    let cfg = dev.config();
    assert_eq!(cfg.len(), 40);
    let (tag_field, queue_count) = cfg.split_at(36);
    assert_eq!(&tag_field[..6], b"shared");
    assert!(tag_field[6..].iter().all(|&b| b == 0));
    assert_eq!(queue_count, &1u32.to_le_bytes());
}
/// A non-empty DAX window is reported as exactly one shared-memory region with id 0.
#[test]
fn shm_region_advertised() {
    let dev = VirtioFs::new(VirtioFsConfig {
        tag: "t".into(),
        num_request_queues: 1,
        dax_window_gpa: 0x100_0000_0000,
        dax_window_len: 0x4000,
    });
    let regs = dev.shm_regions();
    assert_eq!(regs.len(), 1);
    let region = &regs[0];
    assert_eq!(region.id, 0);
    assert_eq!(region.gpa, 0x100_0000_0000);
    assert_eq!(region.len, 0x4000);
}
/// `reset_for_restore` must empty the notification pool.
#[test]
fn reset_for_restore_clears_notif_pool() {
    let dev = VirtioFs::new(VirtioFsConfig {
        tag: "t".into(),
        num_request_queues: 1,
        dax_window_gpa: 0x100_0000_0000,
        dax_window_len: 0x4000,
    });
    let stale = HeldChain {
        head: 7,
        chain: Vec::new(),
    };
    dev.core.notif_pool.lock().unwrap().push(stale);
    let parked = dev.core.notif_pool.lock().unwrap().len();
    assert_eq!(parked, 1);

    dev.reset_for_restore();

    let pool_is_empty = dev.core.notif_pool.lock().unwrap().is_empty();
    assert!(
        pool_is_empty,
        "reset_for_restore must clear stale cross-cycle notification chains"
    );
}
use super::super::queue::{GuestMem, VRING_DESC_F_NEXT, VRING_DESC_F_WRITE};
use crate::fuse::{InitIn, OutHeader, FUSE_KERNEL_VERSION};
use proptest::prelude::*;
// Base address and size of the fake guest memory used by the wire tests.
const MEM_BASE: u64 = 0x10_0000;
const MEM_LEN: usize = 64 * 1024;
// Offsets from MEM_BASE of the descriptor table, avail ring, used ring,
// request buffer and reply buffer (see `build_request_queue`).
const O_DESC: u64 = 0x000;
const O_AVAIL: u64 = 0x800;
const O_USED: u64 = 0x1000;
const O_REQ: u64 = 0x2000;
const O_REPLY: u64 = 0x3000;
// Length advertised for the device-writable reply descriptor.
const REPLY_CAP: u32 = 4096;
/// What the device left in guest memory after a request (see `read_drive_out`).
struct DriveOut {
/// u16 read at used-ring offset +2.
used_idx: u16,
/// u32 read at used-ring offset +4.
used_id: u32,
/// u32 read at used-ring offset +8.
written: u32,
/// First `min(written, REPLY_CAP)` bytes of the reply buffer.
reply: Vec<u8>,
}
/// Lays out a two-descriptor chain in `mem` (request at `O_REQ`, reply buffer at
/// `O_REPLY`), publishes it, and returns a ready `Queue` pointing at it.
///
/// `desc_len_override` replaces the length written into the request descriptor,
/// letting a test advertise a length different from the bytes actually written.
fn build_request_queue(
mem: &GuestMem,
request: &[u8],
desc_len_override: Option<u32>,
) -> Queue {
// Never copy more request bytes than fit between O_REQ and the end of memory.
let n = request.len().min(MEM_LEN - O_REQ as usize);
mem.write_slice(MEM_BASE + O_REQ, &request[..n]);
// Descriptor 0: the request buffer, chained (F_NEXT) to descriptor 1.
// Fields are written at +0, +8, +12 and +14 — presumably addr, len, flags, next.
let d0 = MEM_BASE + O_DESC;
mem.write_u64(d0, MEM_BASE + O_REQ);
mem.write_u32(d0 + 8, desc_len_override.unwrap_or(request.len() as u32));
mem.write_u16(d0 + 12, VRING_DESC_F_NEXT);
mem.write_u16(d0 + 14, 1);
// Descriptor 1: the device-writable reply buffer, end of chain.
let d1 = MEM_BASE + O_DESC + 16;
mem.write_u64(d1, MEM_BASE + O_REPLY);
mem.write_u32(d1 + 8, REPLY_CAP);
mem.write_u16(d1 + 12, VRING_DESC_F_WRITE);
mem.write_u16(d1 + 14, 0);
// Avail ring: +4 gets head 0 and +2 gets 1 — presumably ring[0] and idx.
mem.write_u16(MEM_BASE + O_AVAIL + 4, 0);
mem.write_u16(MEM_BASE + O_AVAIL + 2, 1);
let mut req_q = Queue::new(mem.clone());
req_q.size = 8;
req_q.ready = true;
req_q.desc_table = MEM_BASE + O_DESC;
req_q.avail_ring = MEM_BASE + O_AVAIL;
req_q.used_ring = MEM_BASE + O_USED;
req_q
}
/// Reads three used-ring fields and the reply bytes back from guest memory.
///
/// The reply is truncated to `REPLY_CAP` even if the reported length is larger.
fn read_drive_out(mem: &GuestMem) -> DriveOut {
    let used_ring = MEM_BASE + O_USED;
    let used_idx = mem.read_u16(used_ring + 2);
    let used_id = mem.read_u32(used_ring + 4);
    let written = mem.read_u32(used_ring + 8);

    let reply_len = written.min(REPLY_CAP) as usize;
    let mut reply = vec![0u8; reply_len];
    mem.read_slice(MEM_BASE + O_REPLY, &mut reply);

    DriveOut {
        used_idx,
        used_id,
        written,
        reply,
    }
}
/// Builds a device with one request queue and a 16 KiB DAX window, backed by the
/// default in-memory filesystem.
fn test_dev() -> VirtioFs {
VirtioFs::new(VirtioFsConfig {
tag: "t".into(),
num_request_queues: 1,
dax_window_gpa: 0x100_0000_0000,
dax_window_len: 0x4000,
})
}
/// Feeds one raw request through the device on the calling thread and returns
/// what the device wrote back.
///
/// The queues are installed directly and `drain_request_queue` is called
/// synchronously, so this test helper never kicks the worker thread.
fn drive_request(request: &[u8], desc_len_override: Option<u32>) -> DriveOut {
// `backing` is declared before `dev`, so it is dropped after `dev` (and after the
// worker thread that `dev`'s Drop joins); `mem` points into it.
let mut backing = vec![0u8; MEM_LEN];
let mem = GuestMem::new(backing.as_mut_ptr(), MEM_BASE, MEM_LEN);
let req_q = build_request_queue(&mem, request, desc_len_override);
let dev = test_dev();
{
let mut qs = dev.core.queues.lock().unwrap();
// Index 0: an unconfigured placeholder; index 1: the request queue.
qs.push(Queue::new(mem.clone()));
qs.push(req_q);
}
dev.core.activated.store(true, Ordering::Release);
dev.core.drain_request_queue();
read_drive_out(&mem)
}
/// Returns the in-memory bytes of an `InHeader` whose `len` field covers the
/// header plus `payload_len` bytes; uid, gid, pid and padding are zero.
fn in_header_bytes(opcode: u32, unique: u64, nodeid: u64, payload_len: usize) -> Vec<u8> {
let hdr = InHeader {
len: (core::mem::size_of::<InHeader>() + payload_len) as u32,
opcode,
unique,
nodeid,
uid: 0,
gid: 0,
pid: 0,
padding: 0,
};
// SAFETY: `hdr` is a live local and exactly `size_of::<InHeader>()` bytes are read;
// the slice is copied into a Vec before `hdr` goes out of scope.
// NOTE(review): assumes `InHeader` has no padding bytes — confirm against its definition.
unsafe {
std::slice::from_raw_parts(
&hdr as *const InHeader as *const u8,
core::mem::size_of::<InHeader>(),
)
}
.to_vec()
}
/// Builds a request with opcode 26, unique 1 and nodeid 0 whose payload is an
/// `InitIn` carrying `FUSE_KERNEL_VERSION` as major version and zeros elsewhere.
// NOTE(review): opcode 26 is presumably FUSE_INIT — confirm against the fuse module.
fn valid_init_request() -> Vec<u8> {
let init = InitIn {
major: FUSE_KERNEL_VERSION,
minor: 0,
max_readahead: 0,
flags: 0,
flags2: 0,
unused: [0; 11],
};
// SAFETY: `init` is a live local and exactly `size_of::<InitIn>()` bytes are read;
// the bytes are copied into `req` before `init` goes out of scope.
let init_bytes = unsafe {
std::slice::from_raw_parts(
&init as *const InitIn as *const u8,
core::mem::size_of::<InitIn>(),
)
};
let mut req = in_header_bytes( 26, 1, 0, init_bytes.len());
req.extend_from_slice(init_bytes);
req
}
/// A well-formed init request is consumed once and answered with a reply whose
/// declared length equals the bytes written and whose error field is zero.
#[test]
fn wire_valid_init_round_trips() {
    let out = drive_request(&valid_init_request(), None);
    assert_eq!(out.used_idx, 1, "chain must be consumed once");
    assert_eq!(out.used_id, 0);
    let out_hdr_len = core::mem::size_of::<OutHeader>();
    assert!(out.written as usize >= out_hdr_len);

    // The reply starts with a little-endian u32 length followed by an i32 error.
    let (len_field, rest) = out.reply.split_at(4);
    let declared = u32::from_le_bytes(len_field.try_into().unwrap());
    let error = i32::from_le_bytes(rest[..4].try_into().unwrap());
    assert_eq!(
        declared, out.written,
        "reply len must match bytes written"
    );
    assert_eq!(error, 0, "INIT should succeed");
}
/// A doorbell on queue 1 makes the worker thread service a newly published request.
#[test]
fn notify_drains_request_queue_on_the_worker_thread() {
let mut backing = vec![0u8; MEM_LEN];
let mem = GuestMem::new(backing.as_mut_ptr(), MEM_BASE, MEM_LEN);
let req_q = build_request_queue(&mem, &valid_init_request(), None);
// Un-publish the chain (the u16 at avail +2 back to 0) so activation finds nothing pending.
mem.write_u16(MEM_BASE + O_AVAIL + 2, 0);
let dev = test_dev();
dev.activate(vec![Queue::new(mem.clone()), req_q]);
dev.wait_io_idle(); assert_eq!(read_drive_out(&mem).used_idx, 0);
// Publish the chain and ring the doorbell.
mem.write_u16(MEM_BASE + O_AVAIL + 2, 1);
dev.notify(1);
dev.wait_io_idle();
let out = read_drive_out(&mem);
assert_eq!(out.used_idx, 1, "worker must have serviced the request");
let error = i32::from_le_bytes(out.reply[4..8].try_into().unwrap());
assert_eq!(error, 0, "INIT should succeed via the worker path");
}
/// `activate` kicks the worker, so a chain already published in the avail ring
/// is serviced without any doorbell.
#[test]
fn activate_kick_services_request_captured_pending_in_avail_ring() {
let mut backing = vec![0u8; MEM_LEN];
let mem = GuestMem::new(backing.as_mut_ptr(), MEM_BASE, MEM_LEN);
let req_q = build_request_queue(&mem, &valid_init_request(), None);
let dev = test_dev();
dev.activate(vec![Queue::new(mem.clone()), req_q]);
dev.wait_io_idle();
let out = read_drive_out(&mem);
assert_eq!(
out.used_idx, 1,
"activate must drain requests captured pending in the snapshot"
);
}
/// While paused, a doorbell must not cause any request to be serviced; after
/// `resume_io` the deferred doorbell must be replayed.
#[test]
fn pause_io_defers_doorbells_and_resume_replays_them() {
let mut backing = vec![0u8; MEM_LEN];
let mem = GuestMem::new(backing.as_mut_ptr(), MEM_BASE, MEM_LEN);
let req_q = build_request_queue(&mem, &valid_init_request(), None);
// Start with nothing published so activation leaves the queue untouched.
mem.write_u16(MEM_BASE + O_AVAIL + 2, 0);
let dev = test_dev();
dev.activate(vec![Queue::new(mem.clone()), req_q]);
dev.wait_io_idle();
dev.pause_io();
mem.write_u16(MEM_BASE + O_AVAIL + 2, 1);
dev.notify(1);
// `wait_io_idle` would block here (the kick stays pending while paused), so give
// the worker a moment to run, then check that it wrote nothing.
std::thread::sleep(std::time::Duration::from_millis(50));
assert_eq!(
read_drive_out(&mem).used_idx,
0,
"no request may be serviced while paused"
);
dev.resume_io();
dev.wait_io_idle();
assert_eq!(
read_drive_out(&mem).used_idx,
1,
"the deferred doorbell must be replayed after resume"
);
}
/// `push_notification` must refuse (return false) while a pause hold is outstanding.
#[test]
fn push_notification_is_dropped_while_paused() {
let dev = test_dev();
// Mark the device activated without installing queues; the pause gate is hit first.
dev.core.activated.store(true, Ordering::Release);
dev.pause_io();
assert!(
!dev.core.push_notification(&[0u8; 8]),
"notification must be dropped during a pause window"
);
dev.resume_io();
}
/// `shutdown_io` may be called twice, and afterwards notifications are refused
/// and a doorbell is harmless.
#[test]
fn shutdown_io_is_idempotent_and_gates_io() {
let dev = test_dev();
dev.shutdown_io();
// The second call finds no join handle and must return without blocking.
dev.shutdown_io(); dev.core.activated.store(true, Ordering::Release);
assert!(
!dev.core.push_notification(&[0u8; 8]),
"no guest-RAM writes after shutdown"
);
// A doorbell after shutdown only sets the `kicked` flag; the worker is gone.
dev.notify(1); }
/// A request shorter than an `InHeader` is consumed and completed with zero bytes written.
#[test]
fn wire_short_request_is_rejected_cleanly() {
    let DriveOut {
        used_idx, written, ..
    } = drive_request(&[0u8; 8], None);
    assert_eq!(used_idx, 1);
    assert_eq!(written, 0, "short request must produce no reply");
}
/// A descriptor advertising 9 MiB (above the 8 MiB request cap) is rejected:
/// the chain is consumed and nothing is written back.
#[test]
fn wire_oversized_descriptor_is_capped_not_allocated() {
    const OVER_CAP_LEN: u32 = 9 * 1024 * 1024;
    let header_only = in_header_bytes(26, 1, 0, 0);
    let out = drive_request(&header_only, Some(OVER_CAP_LEN));
    assert_eq!(out.used_idx, 1);
    assert_eq!(
        out.written, 0,
        "over-cap chain must be rejected with no reply"
    );
}
// Property test: for any opcode, unique id, nodeid and payload of up to 1023 bytes,
// the device consumes the chain exactly once, never reports more than the reply
// descriptor's capacity, and always writes at least an OutHeader whose declared
// length is not smaller than the bytes written.
proptest! {
#![proptest_config(ProptestConfig::with_cases(512))]
#[test]
fn wire_arbitrary_fuse_request_is_safe(
opcode in any::<u32>(),
unique in any::<u64>(),
// Mix the fixed nodeid 1 with arbitrary nodeids.
nodeid in prop_oneof![Just(1u64), any::<u64>()],
payload in proptest::collection::vec(any::<u8>(), 0..1024),
) {
let mut req = in_header_bytes(opcode, unique, nodeid, payload.len());
req.extend_from_slice(&payload);
let out = drive_request(&req, None);
prop_assert_eq!(out.used_idx, 1);
prop_assert_eq!(out.used_id, 0);
prop_assert!(out.written <= REPLY_CAP, "wrote past writable descriptor");
prop_assert!(
out.written as usize >= core::mem::size_of::<OutHeader>(),
"reply smaller than OutHeader ({} bytes)",
out.written,
);
let declared = u32::from_le_bytes(out.reply[0..4].try_into().unwrap());
// The declared length may exceed what fit into the reply buffer, never the reverse.
prop_assert!(
declared >= out.written,
"OutHeader.len {} undershoots written {}",
declared, out.written,
);
}
}
}