#![allow(dead_code)]
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use crate::fuse::{build_inval_entry, build_inval_inode, FsBackend, FuseServer, InHeader, MemoryFs, Notifier};
use super::queue::Queue;
use super::{ShmRegion, VirtioDevice, VIRTIO_ID_FS};
/// Feature bits offered to the driver; bit 32 is VIRTIO_F_VERSION_1.
const FEATURES: u64 = 1u64 << 32;
/// Size of the `tag` field in the virtio-fs config space. The code reserves
/// the last byte for a NUL terminator, so tags may be at most 35 bytes.
const TAG_LEN: usize = 36;
/// Default DAX window size: 8 GiB.
pub const DEFAULT_DAX_WINDOW_BYTES: u64 = 8 * 1024 * 1024 * 1024;
/// User-supplied configuration for a [`VirtioFs`] device.
#[derive(Clone, Debug)]
pub struct VirtioFsConfig {
/// Mount tag exposed in the config space (must fit in 35 bytes plus NUL).
pub tag: String,
/// Number of request virtqueues, in addition to the hiprio queue (>= 1).
pub num_request_queues: u32,
/// Guest-physical base of the DAX shared-memory window; 16 KiB aligned.
pub dax_window_gpa: u64,
/// Length in bytes of the DAX window; 16 KiB aligned.
pub dax_window_len: u64,
}
/// Virtio filesystem device that bridges virtqueue traffic to a FUSE server.
pub struct VirtioFs {
// Static device configuration (tag, queue count, DAX window).
cfg: VirtioFsConfig,
// Queues handed over at activate(); index 0 is hiprio, 1.. are request queues.
queues: Mutex<Vec<Queue>>,
// Set once activate() has run; gates all queue processing.
activated: AtomicBool,
// Callback that raises the device interrupt toward the guest, if installed.
irq_raise: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
// FUSE protocol server that dispatches each request.
fuse: Mutex<FuseServer>,
// Write-only hiprio chains stashed for outgoing FUSE notifications.
notif_pool: Mutex<Vec<HeldChain>>,
}
/// A guest descriptor chain held back from the hiprio queue until an
/// outgoing notification needs a buffer.
struct HeldChain {
// Head descriptor index, required to return the chain via add_used().
head: u16,
// The descriptors making up the chain (all device-writable when stashed).
chain: Vec<super::queue::Desc>,
}
impl VirtioFs {
    /// Build a device backed by the default in-memory filesystem.
    pub fn new(cfg: VirtioFsConfig) -> Self {
        Self::with_backend(cfg, Arc::new(MemoryFs::new()))
    }

    /// Build a device over an arbitrary filesystem backend.
    ///
    /// # Panics
    /// Panics when the DAX window base or length is not 16 KiB aligned,
    /// when the tag does not fit in the 36-byte config field (one byte is
    /// reserved for a NUL terminator), or when no request queue is
    /// configured.
    pub fn with_backend(cfg: VirtioFsConfig, backend: Arc<dyn FsBackend>) -> Self {
        assert!(
            cfg.dax_window_gpa & 0x3FFF == 0,
            "DAX window GPA must be 16 KiB aligned (got {:#x})",
            cfg.dax_window_gpa
        );
        assert!(
            cfg.dax_window_len & 0x3FFF == 0,
            "DAX window len must be 16 KiB aligned (got {:#x})",
            cfg.dax_window_len
        );
        assert!(
            cfg.tag.len() < TAG_LEN,
            "tag too long (max {} bytes, got {})",
            TAG_LEN - 1,
            cfg.tag.len()
        );
        assert!(
            cfg.num_request_queues >= 1,
            "must have at least 1 request queue"
        );
        Self {
            cfg,
            queues: Mutex::new(Vec::new()),
            activated: AtomicBool::new(false),
            irq_raise: Mutex::new(None),
            fuse: Mutex::new(FuseServer::new(backend)),
            notif_pool: Mutex::new(Vec::new()),
        }
    }

    /// Register the callback used to raise the device interrupt.
    pub fn set_irq_raise(&self, f: Arc<dyn Fn() + Send + Sync>) {
        *self.irq_raise.lock().unwrap() = Some(f);
    }

    /// Expose the FUSE server, e.g. for out-of-band backend access.
    pub fn fuse_server(&self) -> &Mutex<FuseServer> {
        &self.fuse
    }

    /// Process the hiprio queue (queue 0).
    ///
    /// Chains containing at least one driver-readable descriptor are treated
    /// as guest requests and completed immediately with a zero-length reply;
    /// chains that are entirely device-writable are notification buffers and
    /// are stashed in `notif_pool` until `push_notification` needs one.
    fn drain_hiprio_queue(&self) {
        if !self.activated.load(Ordering::Acquire) {
            return;
        }
        let mut qs = self.queues.lock().unwrap();
        let Some(hi) = qs.get_mut(0) else { return };
        if !hi.ready {
            return;
        }
        let mut acked_request = false;
        while let Some((head, chain)) = hi.pop_chain() {
            let has_readable = chain
                .iter()
                .any(|d| d.flags & super::queue::VRING_DESC_F_WRITE == 0);
            if has_readable {
                hi.add_used(head, 0);
                acked_request = true;
            } else {
                self.notif_pool.lock().unwrap().push(HeldChain { head, chain });
                log_notify(|| format!(
                    "hipri stashed: pool size now {}",
                    self.notif_pool.lock().unwrap().len()
                ));
            }
        }
        drop(qs);
        // Only interrupt the guest when something was actually completed.
        if acked_request {
            if let Some(f) = self.irq_raise.lock().unwrap().clone() {
                f();
            }
        }
    }

    /// Deliver a FUSE notification to the guest via a stashed hiprio buffer.
    ///
    /// Returns `true` when `bytes` was written in full and the buffer handed
    /// back to the guest. On any failure the popped buffer is returned to
    /// the pool so the guest's descriptor chain is never leaked.
    fn push_notification(&self, bytes: &[u8]) -> bool {
        if !self.activated.load(Ordering::Acquire) {
            return false;
        }
        let held = self.notif_pool.lock().unwrap().pop();
        let Some(held) = held else { return false };
        let mut qs = self.queues.lock().unwrap();
        let Some(q) = qs.get_mut(0) else {
            // Queue 0 is gone (e.g. not yet activated with queues): put the
            // buffer back instead of silently dropping the guest's chain.
            drop(qs);
            self.notif_pool.lock().unwrap().push(held);
            return false;
        };
        // A truncated notification would be garbage to the guest parser, so
        // check the writable capacity up front and bail out (keeping the
        // buffer for a later, smaller message) when `bytes` does not fit.
        let capacity: usize = held
            .chain
            .iter()
            .filter(|d| d.flags & super::queue::VRING_DESC_F_WRITE != 0)
            .map(|d| d.len as usize)
            .sum();
        if capacity < bytes.len() {
            drop(qs);
            self.notif_pool.lock().unwrap().push(held);
            return false;
        }
        // Scatter the message across the chain's writable descriptors.
        let mut written = 0usize;
        for d in held
            .chain
            .iter()
            .filter(|d| d.flags & super::queue::VRING_DESC_F_WRITE != 0)
        {
            if written >= bytes.len() {
                break;
            }
            let take = (bytes.len() - written).min(d.len as usize);
            q.mem.write_slice(d.addr, &bytes[written..written + take]);
            written += take;
        }
        q.add_used(held.head, written as u32);
        drop(qs);
        if let Some(f) = self.irq_raise.lock().unwrap().clone() {
            f();
        }
        true
    }

    /// Process every request queue: read each FUSE request out of guest
    /// memory, dispatch it to the FUSE server, and scatter the reply back
    /// into the chain's writable descriptors.
    fn drain_request_queue(&self) {
        if !self.activated.load(Ordering::Acquire) {
            return;
        }
        let mut qs = self.queues.lock().unwrap();
        let mut any = false;
        // Cap on readable payload per chain; protects the host from a guest
        // advertising absurdly large descriptors.
        const MAX_REQUEST_BYTES: usize = 8 * 1024 * 1024;
        // Iterate all request queues (index 1..), not just the first: with
        // num_request_queues > 1 the guest may submit on any of them.
        for qi in 1..qs.len() {
            let q = &mut qs[qi];
            if !q.ready {
                continue;
            }
            while let Some((head, chain)) = q.pop_chain() {
                // Gather all driver-readable descriptors into one buffer.
                let mut in_bytes: Vec<u8> = Vec::new();
                let mut over_cap = false;
                for d in chain.iter().filter(|d| d.flags & super::queue::VRING_DESC_F_WRITE == 0) {
                    let want = d.len as usize;
                    if in_bytes.len().saturating_add(want) > MAX_REQUEST_BYTES {
                        eprintln!(
                            "[virtio-fs] request chain exceeds {} byte cap; rejecting",
                            MAX_REQUEST_BYTES
                        );
                        over_cap = true;
                        break;
                    }
                    let off = in_bytes.len();
                    in_bytes.resize(off + want, 0);
                    q.mem.read_slice(d.addr, &mut in_bytes[off..]);
                }
                if over_cap {
                    // Complete with zero bytes written so the ring advances.
                    q.add_used(head, 0);
                    any = true;
                    continue;
                }
                let hdr_size = core::mem::size_of::<InHeader>();
                if in_bytes.len() < hdr_size {
                    eprintln!(
                        "[virtio-fs] short request: {} bytes < InHeader ({hdr_size})",
                        in_bytes.len()
                    );
                    q.add_used(head, 0);
                    any = true;
                    continue;
                }
                // SAFETY: in_bytes holds at least size_of::<InHeader>() bytes
                // and read_unaligned tolerates any source alignment. InHeader
                // is the FUSE wire-format header (assumed valid for any bit
                // pattern — it is read straight off the wire here).
                let hdr: InHeader = unsafe { core::ptr::read_unaligned(in_bytes.as_ptr() as *const InHeader) };
                let payload = &in_bytes[hdr_size..];
                let reply = self.fuse.lock().unwrap().dispatch(&hdr, payload);
                // Write the reply, truncating if the guest supplied too
                // little writable space.
                let mut written = 0usize;
                for d in chain.iter().filter(|d| d.flags & super::queue::VRING_DESC_F_WRITE != 0) {
                    if written >= reply.bytes.len() {
                        break;
                    }
                    let take = (reply.bytes.len() - written).min(d.len as usize);
                    q.mem.write_slice(d.addr, &reply.bytes[written..written + take]);
                    written += take;
                }
                q.add_used(head, written as u32);
                any = true;
            }
        }
        drop(qs);
        if any {
            if let Some(f) = self.irq_raise.lock().unwrap().clone() {
                f();
            }
        }
    }
}
impl Notifier for VirtioFs {
    /// Queue a FUSE INVAL_INODE notification for `nodeid` covering the
    /// byte range described by `off`/`len`.
    fn invalidate_inode(&self, nodeid: u64, off: i64, len: i64) {
        let msg = build_inval_inode(nodeid, off, len);
        let pushed = self.push_notification(&msg);
        log_notify(|| {
            let pool_left = self.notif_pool.lock().unwrap().len();
            format!("INVAL_INODE nodeid={nodeid} off={off} len={len} pushed={pushed} pool_left={pool_left}")
        });
    }

    /// Queue a FUSE INVAL_ENTRY notification for the entry `name` under
    /// `parent_nodeid`.
    fn invalidate_entry(&self, parent_nodeid: u64, name: &[u8]) {
        let msg = build_inval_entry(parent_nodeid, name);
        let pushed = self.push_notification(&msg);
        log_notify(|| {
            let shown = std::str::from_utf8(name).unwrap_or("<bytes>");
            let pool_left = self.notif_pool.lock().unwrap().len();
            format!("INVAL_ENTRY parent={parent_nodeid} name={shown:?} pushed={pushed} pool_left={pool_left}")
        });
    }
}
/// Emit one trace line when the `SUPERMACHINE_FUSE_TRACE` env var is set.
///
/// The variable selects the sink: `"1"` or `"stderr"` logs to stderr, any
/// other value is treated as a file path to append to. `make_msg` is only
/// invoked when tracing is enabled, keeping the disabled path cheap.
fn log_notify<F: FnOnce() -> String>(make_msg: F) {
    use std::io::Write;
    let target = match std::env::var_os("SUPERMACHINE_FUSE_TRACE") {
        Some(v) => v.to_string_lossy().into_owned(),
        None => return,
    };
    let msg = make_msg();
    match target.as_str() {
        "1" | "stderr" => eprintln!("[virtio-fs] {msg}"),
        path => {
            // Best effort: a failed open or write drops the line silently.
            if let Ok(mut f) = std::fs::OpenOptions::new().create(true).append(true).open(path) {
                let _ = writeln!(f, "[virtio-fs] {msg}");
            }
        }
    }
}
impl VirtioDevice for VirtioFs {
    fn device_id(&self) -> u32 {
        VIRTIO_ID_FS
    }

    /// One hiprio queue plus the configured request queues.
    fn num_queues(&self) -> usize {
        self.cfg.num_request_queues as usize + 1
    }

    fn features(&self) -> u64 {
        FEATURES
    }

    /// Config space layout: a zero-padded 36-byte tag followed by the
    /// request-queue count as a little-endian u32.
    fn config(&self) -> Vec<u8> {
        let mut space = vec![0u8; TAG_LEN + 4];
        let tag = self.cfg.tag.as_bytes();
        let n = tag.len().min(TAG_LEN - 1);
        space[..n].copy_from_slice(&tag[..n]);
        space[TAG_LEN..].copy_from_slice(&self.cfg.num_request_queues.to_le_bytes());
        space
    }

    /// Queue 0 is the hiprio queue; every other index is a request queue.
    fn notify(&self, q: u16) {
        if q == 0 {
            self.drain_hiprio_queue()
        } else {
            self.drain_request_queue()
        }
    }

    fn activate(&self, queues: Vec<Queue>) {
        *self.queues.lock().unwrap() = queues;
        self.activated.store(true, Ordering::Release);
        let dax_start = self.cfg.dax_window_gpa;
        let dax_len = self.cfg.dax_window_len;
        eprintln!(
            "[virtio-fs] activated: tag={:?} req_queues={} dax_window={:#x}..{:#x} ({} MiB)",
            self.cfg.tag,
            self.cfg.num_request_queues,
            dax_start,
            dax_start + dax_len,
            dax_len / (1024 * 1024),
        );
    }

    fn snapshot_queues(&self) -> Vec<Queue> {
        self.queues.lock().unwrap().clone()
    }

    /// Advertise the DAX window as shared-memory region 0.
    fn shm_regions(&self) -> Vec<ShmRegion> {
        let window = ShmRegion {
            id: 0,
            gpa: self.cfg.dax_window_gpa,
            len: self.cfg.dax_window_len,
        };
        vec![window]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Config space must be the 36-byte zero-padded tag followed by the
    // request-queue count in little-endian.
    #[test]
    fn config_layout() {
        let dev = VirtioFs::new(VirtioFsConfig {
            tag: "shared".into(),
            num_request_queues: 1,
            dax_window_gpa: 0x80_0000_0000,
            dax_window_len: DEFAULT_DAX_WINDOW_BYTES,
        });
        let cfg = dev.config();
        assert_eq!(cfg.len(), 40);
        assert_eq!(&cfg[..6], b"shared");
        assert!(cfg[6..36].iter().all(|&b| b == 0));
        assert_eq!(&cfg[36..], &1u32.to_le_bytes());
    }

    // The DAX window must be advertised as shared-memory region 0.
    #[test]
    fn shm_region_advertised() {
        let dev = VirtioFs::new(VirtioFsConfig {
            tag: "t".into(),
            num_request_queues: 1,
            dax_window_gpa: 0x100_0000_0000,
            dax_window_len: 0x4000,
        });
        let regions = dev.shm_regions();
        assert_eq!(regions.len(), 1);
        let r = &regions[0];
        assert_eq!(r.id, 0);
        assert_eq!(r.gpa, 0x100_0000_0000);
        assert_eq!(r.len, 0x4000);
    }
}