#![allow(dead_code)]
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use crate::fuse::{
build_inval_entry, build_inval_inode, FsBackend, FuseServer, InHeader, MemoryFs, Notifier,
};
use super::queue::Queue;
use super::{ShmRegion, VirtioDevice, VIRTIO_ID_FS};
/// Feature bits offered to the driver: only bit 32 is set.
// NOTE(review): bit 32 is presumably VIRTIO_F_VERSION_1 — confirm against the virtio spec.
const FEATURES: u64 = 1u64 << 32;
/// Size in bytes of the tag field at the start of the device config space.
const TAG_LEN: usize = 36;
/// Default DAX window size in bytes (8 GiB).
pub const DEFAULT_DAX_WINDOW_BYTES: u64 = 8 * 1024 * 1024 * 1024;
/// Static configuration for a [`VirtioFs`] device.
#[derive(Clone, Debug)]
pub struct VirtioFsConfig {
/// Mount tag copied into the device config space; must be shorter than `TAG_LEN` bytes.
pub tag: String,
/// Number of request queues (at least 1); the device exposes one more queue than this.
pub num_request_queues: u32,
/// Guest-physical address of the DAX window; must be 16 KiB aligned.
pub dax_window_gpa: u64,
/// Length of the DAX window in bytes; must be 16 KiB aligned.
/// A length of zero means no shared-memory region is advertised.
pub dax_window_len: u64,
}
/// Virtio-fs device: queue 0 is the high-priority queue, later queues carry FUSE requests.
pub struct VirtioFs {
/// Configuration captured at construction.
cfg: VirtioFsConfig,
/// Virtqueues installed by `activate`; empty until then.
queues: Mutex<Vec<Queue>>,
/// Set by `activate`; the drain and notification paths bail out while this is false.
activated: AtomicBool,
/// Optional callback invoked after entries have been added to a used ring.
irq_raise: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
/// FUSE dispatcher that turns each request into a reply.
fuse: Mutex<FuseServer>,
/// Write-only chains taken from queue 0 and held until a notification needs a buffer.
notif_pool: Mutex<Vec<HeldChain>>,
}
/// A descriptor chain popped from queue 0 and kept for a later notification.
struct HeldChain {
/// Head descriptor index, handed back to `add_used` when the chain is completed.
head: u16,
/// The chain's descriptors as returned by `pop_chain`.
chain: Vec<super::queue::Desc>,
}
impl VirtioFs {
pub fn new(cfg: VirtioFsConfig) -> Self {
Self::with_backend(cfg, Arc::new(MemoryFs::new()))
}
/// Creates a device that serves FUSE requests from `backend`.
///
/// # Panics
/// Panics when the DAX window address or length is not a multiple of 16 KiB, when the
/// tag is `TAG_LEN` bytes or longer, or when `num_request_queues` is zero.
pub fn with_backend(cfg: VirtioFsConfig, backend: Arc<dyn FsBackend>) -> Self {
    // 16 KiB granularity, i.e. the low 14 bits must be clear.
    const DAX_ALIGN: u64 = 0x4000;

    let window_gpa = cfg.dax_window_gpa;
    assert!(
        window_gpa % DAX_ALIGN == 0,
        "DAX window GPA must be 16 KiB aligned (got {:#x})",
        window_gpa
    );

    let window_len = cfg.dax_window_len;
    assert!(
        window_len % DAX_ALIGN == 0,
        "DAX window len must be 16 KiB aligned (got {:#x})",
        window_len
    );

    // One byte of the tag field is kept free, so the tag must be strictly shorter.
    let tag_len = cfg.tag.len();
    assert!(
        tag_len < TAG_LEN,
        "tag too long (max {} bytes, got {})",
        TAG_LEN - 1,
        tag_len
    );

    assert!(
        cfg.num_request_queues != 0,
        "must have at least 1 request queue"
    );

    let fuse = Mutex::new(FuseServer::new(backend));
    Self {
        cfg,
        queues: Mutex::new(Vec::new()),
        activated: AtomicBool::new(false),
        irq_raise: Mutex::new(None),
        fuse,
        notif_pool: Mutex::new(Vec::new()),
    }
}
/// Installs the callback used to raise the device interrupt, replacing any previous one.
pub fn set_irq_raise(&self, f: Arc<dyn Fn() + Send + Sync>) {
    let mut slot = self.irq_raise.lock().unwrap();
    *slot = Some(f);
}
/// Discards every held notification chain so none survives into a restored session.
pub fn reset_for_restore(&self) {
    let mut pool = self.notif_pool.lock().unwrap();
    pool.clear();
}
/// Returns the mutex guarding the device's FUSE server.
pub fn fuse_server(&self) -> &Mutex<FuseServer> {
&self.fuse
}
/// Processes every available chain on queue 0.
///
/// A chain with at least one device-readable descriptor is completed immediately with
/// a used length of 0. A chain made only of device-writable descriptors is parked in
/// `notif_pool` so `push_notification` can fill it later. The IRQ callback fires once,
/// after the queue lock is released, and only if something was completed.
fn drain_hiprio_queue(&self) {
    if !self.activated.load(Ordering::Acquire) {
        return;
    }
    let completed_any = {
        let mut queues = self.queues.lock().unwrap();
        let hiprio = match queues.get_mut(0) {
            Some(q) => q,
            None => return,
        };
        if !hiprio.ready {
            return;
        }
        let mut completed = false;
        while let Some((head, chain)) = hiprio.pop_chain() {
            let write_only = chain
                .iter()
                .all(|desc| desc.flags & super::queue::VRING_DESC_F_WRITE != 0);
            if !write_only {
                // Carries a request from the driver: acknowledge it with an empty reply.
                hiprio.add_used(head, 0);
                completed = true;
                continue;
            }
            {
                // Pool lock is taken while the queue lock is held (queues, then pool).
                let mut pool = self.notif_pool.lock().unwrap();
                pool.push(HeldChain { head, chain });
            }
            log_notify(|| {
                format!(
                    "hipri stashed: pool size now {}",
                    self.notif_pool.lock().unwrap().len()
                )
            });
        }
        completed
    };
    if !completed_any {
        return;
    }
    if let Some(f) = self.irq_raise.lock().unwrap().clone() {
        f();
    }
}
/// Delivers one notification message to the guest through a chain parked from queue 0.
///
/// The message is copied into the chain's device-writable descriptors in order; if they
/// are too small, only the bytes that fit are written. The chain is then returned via
/// the used ring with the number of bytes written, and the IRQ callback is invoked.
///
/// Returns `true` when a chain was filled and returned, `false` when the device is not
/// activated, queue 0 does not exist, or no chain is currently parked.
fn push_notification(&self, bytes: &[u8]) -> bool {
    if !self.activated.load(Ordering::Acquire) {
        return false;
    }
    // Lock order (queues, then notif_pool) matches `drain_hiprio_queue`.
    let mut qs = self.queues.lock().unwrap();
    // Resolve the queue BEFORE taking a chain out of the pool: popping first would
    // silently drop the chain if the queue turned out to be missing, and the guest
    // would never get that buffer back.
    let Some(q) = qs.get_mut(0) else { return false };
    let held = match self.notif_pool.lock().unwrap().pop() {
        Some(h) => h,
        None => return false,
    };
    let mut written = 0usize;
    for d in held
        .chain
        .iter()
        .filter(|d| d.flags & super::queue::VRING_DESC_F_WRITE != 0)
    {
        let remaining = &bytes[written..];
        if remaining.is_empty() {
            break;
        }
        let take = remaining.len().min(d.len as usize);
        q.mem.write_slice(d.addr, &remaining[..take]);
        written += take;
    }
    q.add_used(held.head, written as u32);
    // Release the queue lock before raising the interrupt.
    drop(qs);
    if let Some(f) = self.irq_raise.lock().unwrap().clone() {
        f();
    }
    true
}
/// Serves every pending FUSE request on all request queues (every queue after index 0).
///
/// The device advertises `1 + num_request_queues` queues and `notify` routes every
/// non-zero queue index here, so each ready request queue is drained rather than only
/// queue 1. For each chain the readable bytes are gathered (capped at
/// `MAX_REQUEST_BYTES`), parsed as an `InHeader` plus payload, dispatched to the FUSE
/// server, and the reply is written into the chain's writable descriptors. Oversized
/// or too-short requests are completed with a used length of 0. The IRQ callback fires
/// once at the end if any chain was completed.
fn drain_request_queue(&self) {
    if !self.activated.load(Ordering::Acquire) {
        return;
    }
    const MAX_REQUEST_BYTES: usize = 8 * 1024 * 1024;
    let hdr_size = core::mem::size_of::<InHeader>();
    let mut qs = self.queues.lock().unwrap();
    let mut any = false;
    // Index 0 is the high-priority queue; everything after it carries requests.
    for q in qs.iter_mut().skip(1) {
        if !q.ready {
            continue;
        }
        while let Some((head, chain)) = q.pop_chain() {
            let in_bytes =
                match super::queue::read_readable_capped(&chain, &q.mem, MAX_REQUEST_BYTES) {
                    Some(b) => b,
                    None => {
                        eprintln!(
                            "[virtio-fs] request chain exceeds {MAX_REQUEST_BYTES} byte cap; rejecting"
                        );
                        q.add_used(head, 0);
                        any = true;
                        continue;
                    }
                };
            if in_bytes.len() < hdr_size {
                eprintln!(
                    "[virtio-fs] short request: {} bytes < InHeader ({hdr_size})",
                    in_bytes.len()
                );
                q.add_used(head, 0);
                any = true;
                continue;
            }
            // SAFETY: `in_bytes` holds at least `size_of::<InHeader>()` bytes (checked
            // just above) and `read_unaligned` places no alignment requirement on the
            // source pointer.
            // NOTE(review): this assumes every bit pattern is a valid `InHeader` — confirm.
            let hdr: InHeader =
                unsafe { core::ptr::read_unaligned(in_bytes.as_ptr() as *const InHeader) };
            let payload = &in_bytes[hdr_size..];
            // NOTE(review): the queue lock is held across dispatch; if a backend ever
            // calls back into the `Notifier` on this thread, `push_notification` would
            // block on the same lock — confirm this cannot happen.
            let reply = self.fuse.lock().unwrap().dispatch(&hdr, payload);
            let written = super::queue::write_writable(&chain, &q.mem, &reply.bytes);
            q.add_used(head, written as u32);
            any = true;
        }
    }
    // Release the queue lock before raising the interrupt.
    drop(qs);
    if any {
        if let Some(f) = self.irq_raise.lock().unwrap().clone() {
            f();
        }
    }
}
}
impl Notifier for VirtioFs {
    /// Builds an inode-invalidation message and tries to push it to the guest.
    ///
    /// The outcome is only traced; a failed push is not reported to the caller.
    fn invalidate_inode(&self, nodeid: u64, off: i64, len: i64) {
        let ok = {
            let bytes = build_inval_inode(nodeid, off, len);
            self.push_notification(&bytes)
        };
        log_notify(|| {
            let pool_left = self.notif_pool.lock().unwrap().len();
            format!(
                "INVAL_INODE nodeid={nodeid} off={off} len={len} pushed={ok} pool_left={}",
                pool_left
            )
        });
    }

    /// Builds a directory-entry invalidation message and tries to push it to the guest.
    ///
    /// The outcome is only traced; a failed push is not reported to the caller.
    fn invalidate_entry(&self, parent_nodeid: u64, name: &[u8]) {
        let ok = {
            let bytes = build_inval_entry(parent_nodeid, name);
            self.push_notification(&bytes)
        };
        log_notify(|| {
            // Names that are not valid UTF-8 are shown as a placeholder in the trace.
            let shown = match std::str::from_utf8(name) {
                Ok(text) => text,
                Err(_) => "<bytes>",
            };
            let pool_left = self.notif_pool.lock().unwrap().len();
            format!(
                "INVAL_ENTRY parent={parent_nodeid} name={:?} pushed={ok} pool_left={}",
                shown, pool_left
            )
        });
    }
}
/// Emits one trace line if FUSE tracing is enabled.
///
/// `make_msg` is only called when `crate::trace::fuse_target()` returns a target, so
/// callers can build the message lazily. A target of `"1"` or `"stderr"` prints to
/// stderr; any other value is treated as a file path that is opened in append mode
/// (created if absent). Failures to open or write the file are ignored.
fn log_notify<F: FnOnce() -> String>(make_msg: F) {
    use std::io::Write;
    let target = match crate::trace::fuse_target() {
        Some(t) => t,
        None => return,
    };
    let s = make_msg();
    let dest = target.to_string_lossy().into_owned();
    match dest.as_str() {
        "1" | "stderr" => eprintln!("[virtio-fs] {s}"),
        path => {
            let mut opts = std::fs::OpenOptions::new();
            opts.create(true).append(true);
            if let Ok(mut f) = opts.open(path) {
                // Tracing is best-effort: a failed write is deliberately dropped.
                let _ = writeln!(f, "[virtio-fs] {s}");
            }
        }
    }
}
impl VirtioDevice for VirtioFs {
/// Returns the virtio device ID for a filesystem device (`VIRTIO_ID_FS`).
fn device_id(&self) -> u32 {
VIRTIO_ID_FS
}
/// Returns the total queue count: one high-priority queue plus the configured request queues.
fn num_queues(&self) -> usize {
1 + self.cfg.num_request_queues as usize
}
/// Returns the feature bits offered to the driver (the `FEATURES` constant).
fn features(&self) -> u64 {
FEATURES
}
/// Serializes the device config space.
///
/// Layout: `TAG_LEN` bytes holding the tag, zero-padded (at most `TAG_LEN - 1` tag
/// bytes are copied), followed by `num_request_queues` as a little-endian `u32`.
fn config(&self) -> Vec<u8> {
    let tag = self.cfg.tag.as_bytes();
    let copied = usize::min(tag.len(), TAG_LEN - 1);
    let mut out = Vec::with_capacity(TAG_LEN + 4);
    out.extend_from_slice(&tag[..copied]);
    // Zero-fill the rest of the tag field.
    out.resize(TAG_LEN, 0u8);
    out.extend_from_slice(&self.cfg.num_request_queues.to_le_bytes());
    out
}
/// Handles a queue notification from the driver.
///
/// Queue 0 is drained by `drain_hiprio_queue`; any other index triggers
/// `drain_request_queue`.
fn notify(&self, q: u16) {
    if q == 0 {
        self.drain_hiprio_queue();
    } else {
        self.drain_request_queue();
    }
}
/// Installs the driver-configured queues, marks the device active and logs the setup.
///
/// The queues are stored before the `activated` flag is set, so the drain paths never
/// see the flag without the queues in place.
fn activate(&self, queues: Vec<Queue>) {
    *self.queues.lock().unwrap() = queues;
    self.activated.store(true, Ordering::Release);
    // The constructor validates alignment only, not that gpa + len fits in u64.
    // Saturate so this purely informational log line cannot panic on overflow.
    let window_end = self
        .cfg
        .dax_window_gpa
        .saturating_add(self.cfg.dax_window_len);
    eprintln!(
        "[virtio-fs] activated: tag={:?} req_queues={} dax_window={:#x}..{:#x} ({} MiB)",
        self.cfg.tag,
        self.cfg.num_request_queues,
        self.cfg.dax_window_gpa,
        window_end,
        self.cfg.dax_window_len / (1024 * 1024),
    );
}
/// Returns a clone of the current queue list, taken under the queue lock.
fn snapshot_queues(&self) -> Vec<Queue> {
self.queues.lock().unwrap().clone()
}
/// Lists the shared-memory regions exposed by the device.
///
/// Returns a single region with id 0 covering the DAX window, or nothing when the
/// configured window length is zero.
fn shm_regions(&self) -> Vec<ShmRegion> {
    let mut regions = Vec::new();
    let window_len = self.cfg.dax_window_len;
    if window_len != 0 {
        regions.push(ShmRegion {
            id: 0,
            gpa: self.cfg.dax_window_gpa,
            len: window_len,
        });
    }
    regions
}
}
#[cfg(test)]
mod tests {
use super::*;
/// The config space is a 36-byte zero-padded tag followed by a little-endian queue count.
#[test]
fn config_layout() {
    let params = VirtioFsConfig {
        tag: "shared".into(),
        num_request_queues: 1,
        dax_window_gpa: 0x80_0000_0000,
        dax_window_len: DEFAULT_DAX_WINDOW_BYTES,
    };
    let bytes = VirtioFs::new(params).config();
    assert_eq!(bytes.len(), 40);
    let (tag_field, queue_field) = bytes.split_at(36);
    assert_eq!(&tag_field[..6], b"shared");
    // Everything after the tag text must be zero padding.
    for pad in tag_field[6..].iter().copied() {
        assert_eq!(pad, 0);
    }
    assert_eq!(queue_field, &1u32.to_le_bytes());
}
/// A non-empty DAX window shows up as exactly one shared-memory region with id 0.
#[test]
fn shm_region_advertised() {
    let params = VirtioFsConfig {
        tag: "t".into(),
        num_request_queues: 1,
        dax_window_gpa: 0x100_0000_0000,
        dax_window_len: 0x4000,
    };
    let regions = VirtioFs::new(params).shm_regions();
    assert_eq!(regions.len(), 1);
    let window = &regions[0];
    assert_eq!(window.id, 0);
    assert_eq!(window.gpa, 0x100_0000_0000);
    assert_eq!(window.len, 0x4000);
}
/// `reset_for_restore` must empty the pool of held notification chains.
#[test]
fn reset_for_restore_clears_notif_pool() {
    let params = VirtioFsConfig {
        tag: "t".into(),
        num_request_queues: 1,
        dax_window_gpa: 0x100_0000_0000,
        dax_window_len: 0x4000,
    };
    let dev = VirtioFs::new(params);
    {
        // Seed the pool with one placeholder chain.
        let mut pool = dev.notif_pool.lock().unwrap();
        pool.push(HeldChain {
            head: 7,
            chain: Vec::new(),
        });
    }
    assert_eq!(dev.notif_pool.lock().unwrap().len(), 1);
    dev.reset_for_restore();
    assert!(
        dev.notif_pool.lock().unwrap().is_empty(),
        "reset_for_restore must clear stale cross-cycle notification chains"
    );
}
use super::super::queue::{GuestMem, VRING_DESC_F_NEXT, VRING_DESC_F_WRITE};
use crate::fuse::{InitIn, OutHeader, FUSE_KERNEL_VERSION};
use proptest::prelude::*;
// Guest-physical base address and byte size of the fake guest memory used by `drive_request`.
const MEM_BASE: u64 = 0x10_0000;
const MEM_LEN: usize = 64 * 1024;
// Offsets from `MEM_BASE` of the descriptor table, avail ring, used ring,
// request buffer and reply buffer.
const O_DESC: u64 = 0x000;
const O_AVAIL: u64 = 0x800;
const O_USED: u64 = 0x1000;
const O_REQ: u64 = 0x2000;
const O_REPLY: u64 = 0x3000;
// Length given to the device-writable reply descriptor.
const REPLY_CAP: u32 = 4096;
/// What `drive_request` reads back from guest memory after the device has run.
struct DriveOut {
/// 16-bit value read at used-ring offset 2.
used_idx: u16,
/// 32-bit value read at used-ring offset 4 (the first used element's id).
used_id: u32,
/// 32-bit value read at used-ring offset 8 (the first used element's length).
written: u32,
/// The first `min(written, REPLY_CAP)` bytes of the reply buffer.
reply: Vec<u8>,
}
/// Runs one request through a freshly built device and returns what the device produced.
///
/// Builds a two-descriptor chain in fake guest memory (descriptor 0 readable, pointing
/// at `request`; descriptor 1 writable, pointing at the reply buffer), publishes it on
/// the request queue, calls `drain_request_queue`, and reads back the used ring and
/// reply bytes. `desc_len_override`, when set, replaces the length advertised by the
/// request descriptor without changing the bytes actually stored.
fn drive_request(request: &[u8], desc_len_override: Option<u32>) -> DriveOut {
let mut backing = vec![0u8; MEM_LEN];
let mem = GuestMem::new(backing.as_mut_ptr(), MEM_BASE, MEM_LEN);
// Copy in only as much of the request as fits between O_REQ and the end of memory.
let n = request.len().min(MEM_LEN - O_REQ as usize);
mem.write_slice(MEM_BASE + O_REQ, &request[..n]);
// Descriptor 0: address (+0), length (+8), flags (+12), next index (+14).
let d0 = MEM_BASE + O_DESC;
mem.write_u64(d0, MEM_BASE + O_REQ);
mem.write_u32(d0 + 8, desc_len_override.unwrap_or(request.len() as u32));
mem.write_u16(d0 + 12, VRING_DESC_F_NEXT);
mem.write_u16(d0 + 14, 1);
// Descriptor 1: device-writable reply buffer, end of chain.
let d1 = MEM_BASE + O_DESC + 16;
mem.write_u64(d1, MEM_BASE + O_REPLY);
mem.write_u32(d1 + 8, REPLY_CAP);
mem.write_u16(d1 + 12, VRING_DESC_F_WRITE);
mem.write_u16(d1 + 14, 0);
// Avail ring: first ring slot (+4) names head descriptor 0, then the index (+2) becomes 1.
mem.write_u16(MEM_BASE + O_AVAIL + 4, 0);
mem.write_u16(MEM_BASE + O_AVAIL + 2, 1);
let mut req_q = Queue::new(mem.clone());
req_q.size = 8;
req_q.ready = true;
req_q.desc_table = MEM_BASE + O_DESC;
req_q.avail_ring = MEM_BASE + O_AVAIL;
req_q.used_ring = MEM_BASE + O_USED;
let dev = VirtioFs::new(VirtioFsConfig {
tag: "t".into(),
num_request_queues: 1,
dax_window_gpa: 0x100_0000_0000,
dax_window_len: 0x4000,
});
{
// Queue 0 is a placeholder high-priority queue; queue 1 is the request queue under test.
let mut qs = dev.queues.lock().unwrap();
qs.push(Queue::new(mem.clone()));
qs.push(req_q);
}
// Mark the device active directly instead of going through `activate`.
dev.activated.store(true, Ordering::Release);
dev.drain_request_queue();
let used_idx = mem.read_u16(MEM_BASE + O_USED + 2);
let used_id = mem.read_u32(MEM_BASE + O_USED + 4);
let written = mem.read_u32(MEM_BASE + O_USED + 8);
// Clamp so a bogus length can never read past the reply buffer.
let mut reply = vec![0u8; written.min(REPLY_CAP) as usize];
mem.read_slice(MEM_BASE + O_REPLY, &mut reply);
DriveOut {
used_idx,
used_id,
written,
reply,
}
}
/// Returns the raw in-memory bytes of an `InHeader` for a request with `payload_len`
/// payload bytes. The header's `len` field covers header plus payload; the uid, gid,
/// pid and padding fields are zero.
fn in_header_bytes(opcode: u32, unique: u64, nodeid: u64, payload_len: usize) -> Vec<u8> {
    let hdr_size = core::mem::size_of::<InHeader>();
    let hdr = InHeader {
        len: (hdr_size + payload_len) as u32,
        opcode,
        unique,
        nodeid,
        uid: 0,
        gid: 0,
        pid: 0,
        padding: 0,
    };
    let raw = &hdr as *const InHeader as *const u8;
    // SAFETY: `raw` points at `hdr`, which is alive for the whole borrow and is exactly
    // `hdr_size` bytes long.
    // NOTE(review): assumes `InHeader` has no implicit padding bytes — confirm its layout.
    let view = unsafe { std::slice::from_raw_parts(raw, hdr_size) };
    view.to_vec()
}
/// A well-formed request with opcode 26 and an `InitIn` payload yields a successful reply
/// whose header length matches the bytes written.
#[test]
fn wire_valid_init_round_trips() {
    let init = InitIn {
        major: FUSE_KERNEL_VERSION,
        minor: 0,
        max_readahead: 0,
        flags: 0,
        flags2: 0,
        unused: [0; 11],
    };
    let init_size = core::mem::size_of::<InitIn>();
    // SAFETY: the pointer refers to `init`, which outlives the slice and is exactly
    // `init_size` bytes long.
    let init_bytes =
        unsafe { std::slice::from_raw_parts(&init as *const InitIn as *const u8, init_size) };

    let mut req = in_header_bytes(26, 1, 0, init_bytes.len());
    req.extend_from_slice(init_bytes);

    let out = drive_request(&req, None);
    assert_eq!(out.used_idx, 1, "chain must be consumed once");
    assert_eq!(out.used_id, 0);
    assert!(out.written as usize >= core::mem::size_of::<OutHeader>());

    // The reply starts with a little-endian u32 length followed by an i32 error code.
    let declared = u32::from_le_bytes(out.reply[0..4].try_into().unwrap());
    let error = i32::from_le_bytes(out.reply[4..8].try_into().unwrap());
    assert_eq!(
        declared, out.written,
        "reply len must match bytes written"
    );
    assert_eq!(error, 0, "INIT should succeed");
}
/// An 8-byte request is too short to hold a header: the chain is consumed with no reply.
#[test]
fn wire_short_request_is_rejected_cleanly() {
    let truncated = [0u8; 8];
    let DriveOut {
        used_idx, written, ..
    } = drive_request(&truncated, None);
    assert_eq!(used_idx, 1);
    assert_eq!(written, 0, "short request must produce no reply");
}
/// A descriptor advertising 9 MiB exceeds the request cap: the chain is consumed with no reply.
#[test]
fn wire_oversized_descriptor_is_capped_not_allocated() {
    let header_only = in_header_bytes(26, 1, 0, 0);
    let claimed_len: u32 = 9 * 1024 * 1024;
    let DriveOut {
        used_idx, written, ..
    } = drive_request(&header_only, Some(claimed_len));
    assert_eq!(used_idx, 1);
    assert_eq!(
        written, 0,
        "over-cap chain must be rejected with no reply"
    );
}
// Property test: any header (arbitrary opcode, unique id and node id) followed by up to
// 1023 arbitrary payload bytes must be consumed exactly once and answered with a reply
// that is at least an OutHeader, fits the writable descriptor, and whose declared
// length is not smaller than the bytes written.
proptest! {
#![proptest_config(ProptestConfig::with_cases(512))]
#[test]
fn wire_arbitrary_fuse_request_is_safe(
opcode in any::<u32>(),
unique in any::<u64>(),
// Node id 1 is mixed in explicitly alongside fully random ids.
nodeid in prop_oneof![Just(1u64), any::<u64>()],
payload in proptest::collection::vec(any::<u8>(), 0..1024),
) {
let mut req = in_header_bytes(opcode, unique, nodeid, payload.len());
req.extend_from_slice(&payload);
let out = drive_request(&req, None);
prop_assert_eq!(out.used_idx, 1);
prop_assert_eq!(out.used_id, 0);
prop_assert!(out.written <= REPLY_CAP, "wrote past writable descriptor");
prop_assert!(
out.written as usize >= core::mem::size_of::<OutHeader>(),
"reply smaller than OutHeader ({} bytes)",
out.written,
);
// First four reply bytes are the little-endian length field of the reply header.
let declared = u32::from_le_bytes(out.reply[0..4].try_into().unwrap());
prop_assert!(
declared >= out.written,
"OutHeader.len {} undershoots written {}",
declared, out.written,
);
}
}
}