use std::fs::{File, OpenOptions};
use std::os::unix::io::AsRawFd;
use std::sync::{Arc, Mutex};
use super::queue::Queue;
use super::{VirtioDevice, VIRTIO_ID_BLOCK};
// Request types carried in the `type` field of the request header.
const VIRTIO_BLK_T_IN: u32 = 0;
const VIRTIO_BLK_T_OUT: u32 = 1;
const VIRTIO_BLK_T_FLUSH: u32 = 4;
const VIRTIO_BLK_T_GET_ID: u32 = 8;
// Completion status values written into the request's final (status) byte.
const VIRTIO_BLK_S_OK: u8 = 0;
const VIRTIO_BLK_S_IOERR: u8 = 1;
const VIRTIO_BLK_S_UNSUPP: u8 = 2;
// Block-device feature bits. SIZE_MAX, SEG_MAX and BLK_SIZE are defined for
// completeness but this device does not advertise them, so they are allowed
// to be unused rather than tripping `dead_code` (and `-D warnings` builds).
#[allow(dead_code)]
const VIRTIO_BLK_F_SIZE_MAX: u64 = 1 << 1;
#[allow(dead_code)]
const VIRTIO_BLK_F_SEG_MAX: u64 = 1 << 2;
const VIRTIO_BLK_F_RO: u64 = 1 << 5;
#[allow(dead_code)]
const VIRTIO_BLK_F_BLK_SIZE: u64 = 1 << 6;
const VIRTIO_BLK_F_FLUSH: u64 = 1 << 9;
/// Device-independent feature bit 32 (VIRTIO_F_VERSION_1).
const VIRTIO_F_VERSION_1: u64 = 1 << 32;
/// Sector size used to turn request sectors into byte offsets and to report
/// the device capacity.
const SECTOR_SIZE: u64 = 512;
/// A virtio block device whose disk contents live in a memory-mapped image
/// file, created by `open_ro` (private read-only mapping) or `open_rw`
/// (shared read/write mapping).
pub struct VirtioBlk {
/// Device name; used in log lines and returned to the guest by GET_ID.
name: String,
/// Base address of the image mapping returned by `mmap`.
backing_ptr: *mut u8,
/// Length of the mapping in bytes; capacity is reported as this / SECTOR_SIZE.
backing_len: usize,
/// True for `open_rw` devices (PROT_WRITE, MAP_SHARED mapping). Gates T_OUT,
/// FLUSH and which feature bits are advertised.
writable: bool,
/// Queues installed by `activate`; only index 0 is serviced. `drain_q` holds
/// this lock for the whole time it processes requests.
queues: Mutex<Vec<Queue>>,
/// Set to true by `activate`.
activated: std::sync::atomic::AtomicBool,
/// Hook installed via `set_irq_raise`; `drain_q` calls it (outside the queue
/// lock) after completing at least one request.
irq_raise: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
}
// SAFETY: `backing_ptr` is the only field that is not automatically
// Send/Sync. It refers to an mmap region of the process, not to any
// thread-local state. `drain_q` reads, writes and msyncs that region only while
// holding the `queues` mutex, so concurrent `&self` callers never touch the
// mapped bytes at the same time; the other `&self` methods use only
// `backing_len`.
unsafe impl Send for VirtioBlk {}
unsafe impl Sync for VirtioBlk {}
impl VirtioBlk {
    /// Size in bytes of the request header: type (le32), reserved/ioprio
    /// (le32), sector (le64).
    const OUTHDR_LEN: u64 = 16;

    /// Opens `path` read-only and maps its whole contents privately.
    ///
    /// # Errors
    /// Propagates errors from opening/stat-ing the file and from `mmap`;
    /// returns `InvalidInput` if the file is empty or too large to map.
    pub fn open_ro(name: &str, path: &str) -> std::io::Result<Self> {
        let f = File::open(path)?;
        let file_len = f.metadata()?.len();
        let (ptr, len) = Self::map_file(
            &f,
            file_len,
            libc::PROT_READ,
            libc::MAP_PRIVATE,
            libc::MADV_SEQUENTIAL,
        )?;
        eprintln!("[virtio-blk:{name}] mmap ro {} bytes from {path}", len);
        // The mapping stays valid after the descriptor is closed.
        drop(f);
        Ok(Self::from_mapping(name, ptr, len, false))
    }

    /// Opens (creating if needed) `path` read/write, grows it to at least
    /// `size_bytes` (never shrinks it), and maps it shared so guest writes
    /// reach the file.
    ///
    /// # Errors
    /// Propagates errors from open/stat/resize and `mmap`; returns
    /// `InvalidInput` if the resulting file is empty or too large to map.
    pub fn open_rw(name: &str, path: &str, size_bytes: u64) -> std::io::Result<Self> {
        let f = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(path)?;
        if f.metadata()?.len() < size_bytes {
            f.set_len(size_bytes)?;
        }
        let file_len = f.metadata()?.len();
        let (ptr, len) = Self::map_file(
            &f,
            file_len,
            libc::PROT_READ | libc::PROT_WRITE,
            libc::MAP_SHARED,
            libc::MADV_RANDOM,
        )?;
        eprintln!("[virtio-blk:{name}] mmap rw {} bytes from {path}", len);
        // The mapping stays valid after the descriptor is closed.
        drop(f);
        Ok(Self::from_mapping(name, ptr, len, true))
    }

    /// Installs the callback run after requests complete.
    pub fn set_irq_raise(&self, f: Arc<dyn Fn() + Send + Sync>) {
        *self.irq_raise.lock().unwrap() = Some(f);
    }

    /// Maps the first `len` bytes of `f` with `prot`/`flags`, applies the
    /// `advice` hint, and returns the mapping's base pointer and length.
    fn map_file(
        f: &File,
        len: u64,
        prot: libc::c_int,
        flags: libc::c_int,
        advice: libc::c_int,
    ) -> std::io::Result<(*mut u8, usize)> {
        let len = usize::try_from(len).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "backing file is too large to map",
            )
        })?;
        if len == 0 {
            // mmap(2) rejects zero-length mappings with a bare EINVAL; say why.
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "backing file is empty",
            ));
        }
        // SAFETY: requests a fresh mapping at a kernel-chosen address, so no
        // existing memory is affected; `f` keeps the fd open for the call.
        let p = unsafe { libc::mmap(std::ptr::null_mut(), len, prot, flags, f.as_raw_fd(), 0) };
        if p == libc::MAP_FAILED {
            return Err(std::io::Error::last_os_error());
        }
        // SAFETY: `p..p + len` is exactly the mapping created above. madvise
        // is only a hint, so its result is deliberately ignored.
        unsafe {
            libc::madvise(p, len, advice);
        }
        Ok((p.cast::<u8>(), len))
    }

    /// Wraps an established mapping in a not-yet-activated device.
    fn from_mapping(name: &str, backing_ptr: *mut u8, backing_len: usize, writable: bool) -> Self {
        Self {
            name: name.to_string(),
            backing_ptr,
            backing_len,
            writable,
            queues: Mutex::new(Vec::new()),
            activated: std::sync::atomic::AtomicBool::new(false),
            irq_raise: Mutex::new(None),
        }
    }

    /// Returns `(start, end)` for a `len`-byte access at `start`, or `None`
    /// if `start` is `None` (an earlier overflow), `start + len` overflows,
    /// or the range extends past the end of the backing mapping.
    fn checked_range(&self, start: Option<u64>, len: u64) -> Option<(u64, u64)> {
        let start = start?;
        let end = start.checked_add(len)?;
        (end <= self.backing_len as u64).then_some((start, end))
    }

    /// Processes every available request on queue 0, then runs the
    /// irq-raise hook (after releasing the queue lock) if any request was
    /// completed.
    ///
    /// Requests are expected as header, zero or more data descriptors, and
    /// a final one-byte status descriptor.
    fn drain_q(&self) {
        let mut qs = self.queues.lock().unwrap();
        let Some(q) = qs.get_mut(0) else {
            return;
        };
        if !q.ready {
            return;
        }
        let mut any_used = false;
        while let Some((head, chain)) = q.pop_chain() {
            any_used = true;
            // Without both a header and a status descriptor there is nowhere
            // to report a status; just return the chain.
            if chain.len() < 2 {
                q.add_used(head, 0);
                continue;
            }
            let hdr = chain[0];
            let status_desc = chain[chain.len() - 1];
            let data = &chain[1..chain.len() - 1];
            // Refuse a header too short to hold type/reserved/sector instead
            // of decoding whatever guest memory happens to follow it.
            if (hdr.len as u64) < Self::OUTHDR_LEN {
                q.mem.write_slice(status_desc.addr, &[VIRTIO_BLK_S_IOERR]);
                q.add_used(head, 1);
                continue;
            }
            let req_type = q.mem.read_u32(hdr.addr);
            let sector = q.mem.read_u64(hdr.addr + 8);
            let mut status = VIRTIO_BLK_S_OK;
            // Bytes written into device-writable buffers; starts at 1 for the
            // status byte.
            let mut bytes_written: u32 = 1;
            match req_type {
                VIRTIO_BLK_T_IN => {
                    let mut off = sector.checked_mul(SECTOR_SIZE);
                    for d in data {
                        let want = d.len as u64;
                        let Some((start, end)) = self.checked_range(off, want) else {
                            status = VIRTIO_BLK_S_IOERR;
                            break;
                        };
                        // SAFETY: checked_range guarantees `start..end` lies
                        // within the `backing_len`-byte mapping at `backing_ptr`.
                        let src = unsafe {
                            std::slice::from_raw_parts(
                                self.backing_ptr.add(start as usize),
                                want as usize,
                            )
                        };
                        q.mem.write_slice(d.addr, src);
                        bytes_written = bytes_written.saturating_add(want as u32);
                        off = Some(end);
                    }
                }
                VIRTIO_BLK_T_OUT => {
                    if !self.writable {
                        status = VIRTIO_BLK_S_UNSUPP;
                    } else {
                        let mut off = sector.checked_mul(SECTOR_SIZE);
                        for d in data {
                            let n = d.len as u64;
                            let Some((start, end)) = self.checked_range(off, n) else {
                                status = VIRTIO_BLK_S_IOERR;
                                break;
                            };
                            // SAFETY: the range is in bounds (checked_range);
                            // writable devices are mapped PROT_WRITE by open_rw;
                            // and the mapping is only accessed here while the
                            // `queues` lock is held, so nothing aliases `dst`.
                            let dst = unsafe {
                                std::slice::from_raw_parts_mut(
                                    self.backing_ptr.add(start as usize),
                                    n as usize,
                                )
                            };
                            q.mem.read_slice(d.addr, dst);
                            off = Some(end);
                        }
                    }
                }
                VIRTIO_BLK_T_FLUSH => {
                    if self.writable {
                        // SAFETY: `backing_ptr`/`backing_len` describe the live
                        // mapping owned by this device.
                        let rc = unsafe {
                            libc::msync(
                                self.backing_ptr as *mut libc::c_void,
                                self.backing_len,
                                libc::MS_SYNC,
                            )
                        };
                        // Report a failed writeback instead of claiming success.
                        if rc != 0 {
                            status = VIRTIO_BLK_S_IOERR;
                        }
                    }
                }
                VIRTIO_BLK_T_GET_ID => {
                    // The ID is a left-aligned, NUL-padded string of up to 20
                    // bytes (no terminator when exactly 20); longer names are
                    // truncated.
                    const ID_BYTES: usize = 20;
                    let mut id = [0u8; ID_BYTES];
                    let name = self.name.as_bytes();
                    let n = name.len().min(ID_BYTES);
                    id[..n].copy_from_slice(&name[..n]);
                    match data.first() {
                        Some(d) => {
                            let take = (d.len as usize).min(ID_BYTES);
                            q.mem.write_slice(d.addr, &id[..take]);
                            bytes_written += take as u32;
                        }
                        // No data buffer: the only writable descriptor is the
                        // status byte, which must not be overwritten with the ID.
                        None => status = VIRTIO_BLK_S_IOERR,
                    }
                }
                _ => status = VIRTIO_BLK_S_UNSUPP,
            }
            q.mem.write_slice(status_desc.addr, &[status]);
            q.add_used(head, bytes_written);
        }
        if any_used {
            let f_opt = self.irq_raise.lock().unwrap().clone();
            // Release the queue lock before calling out.
            drop(qs);
            if let Some(f) = f_opt {
                f();
            }
        }
    }
}

impl Drop for VirtioBlk {
    /// Unmaps the image. For `open_rw` devices the mapping is MAP_SHARED, so
    /// dirty pages stay in the page cache for the kernel to write back.
    fn drop(&mut self) {
        if self.backing_len == 0 {
            return;
        }
        // SAFETY: `backing_ptr`/`backing_len` are exactly the region returned by
        // mmap in the constructor, and `&mut self` proves no other reference to
        // it is live.
        unsafe {
            libc::munmap(self.backing_ptr as *mut libc::c_void, self.backing_len);
        }
    }
}
impl VirtioDevice for VirtioBlk {
    /// Virtio device type reported to the transport.
    fn device_id(&self) -> u32 {
        VIRTIO_ID_BLOCK
    }

    /// The device exposes a single request queue.
    fn num_queues(&self) -> usize {
        1
    }

    /// Device configuration space: only the little-endian capacity, counted
    /// in whole 512-byte sectors (a trailing partial sector is not exposed).
    fn config(&self) -> Vec<u8> {
        let capacity_sectors = self.backing_len as u64 / SECTOR_SIZE;
        Vec::from(capacity_sectors.to_le_bytes())
    }

    /// VERSION_1 always, plus FLUSH for writable images or RO otherwise.
    fn features(&self) -> u64 {
        let access = if self.writable {
            VIRTIO_BLK_F_FLUSH
        } else {
            VIRTIO_BLK_F_RO
        };
        VIRTIO_F_VERSION_1 | access
    }

    /// Any notification services queue 0; the queue index is ignored.
    fn notify(&self, _q: u16) {
        self.drain_q();
    }

    /// Replaces the installed queues with `queues` and marks the device active.
    fn activate(&self, queues: Vec<Queue>) {
        use std::sync::atomic::Ordering;

        {
            let mut installed = self.queues.lock().unwrap();
            *installed = queues;
        }
        self.activated.store(true, Ordering::Release);
        let sectors = self.backing_len as u64 / SECTOR_SIZE;
        eprintln!("[virtio-blk:{}] activated, {} sectors", self.name, sectors);
    }

    /// Returns a copy of the currently installed queues.
    fn snapshot_queues(&self) -> Vec<Queue> {
        let guard = self.queues.lock().unwrap();
        guard.to_vec()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::devices::virtio::queue::{GuestMem, VRING_DESC_F_NEXT, VRING_DESC_F_WRITE};
    use std::io::Write;

    /// Guest-physical base address of the scratch window built by `run`.
    const BASE: u64 = 0x10_0000;
    /// Size in bytes of the scratch window.
    const WIN: usize = 256 * 1024;
    // Offsets within the window: descriptor table, avail/used rings, and the
    // request header, data buffer and status byte.
    const O_DESC: u64 = 0x0000;
    const O_AVAIL: u64 = 0x0800;
    const O_USED: u64 = 0x1000;
    const O_HDR: u64 = 0x2000;
    const O_DATA: u64 = 0x3000;
    const O_STATUS: u64 = 0x4000;

    /// A uniquely named temporary image path that is removed on drop, so the
    /// file is cleaned up even when an assertion fails mid-test.
    struct TempImage(std::path::PathBuf);

    impl TempImage {
        fn new(tag: &str) -> Self {
            // Tests run in parallel and clock resolution can be coarse, so a
            // per-process counter guarantees distinct names.
            static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
            let seq = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            Self(std::env::temp_dir().join(format!(
                "sm-blk-{tag}-{}-{seq}-{nanos}.img",
                std::process::id()
            )))
        }

        fn as_str(&self) -> &str {
            self.0.to_str().expect("temp dir path is valid UTF-8")
        }
    }

    impl Drop for TempImage {
        fn drop(&mut self) {
            // Best effort: the file may not exist if setup failed early.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    fn make_rw(size: u64) -> (VirtioBlk, TempImage) {
        let img = TempImage::new("rw");
        let dev = VirtioBlk::open_rw("testvol", img.as_str(), size).unwrap();
        (dev, img)
    }

    fn make_ro(size: u64) -> (VirtioBlk, TempImage) {
        let img = TempImage::new("ro");
        let mut f = File::create(&img.0).unwrap();
        f.write_all(&vec![0u8; size as usize]).unwrap();
        drop(f);
        let dev = VirtioBlk::open_ro("testro", img.as_str()).unwrap();
        (dev, img)
    }

    struct Resp {
        status: u8,
        data: Vec<u8>,
    }

    /// Submits one header/data/status request to `dev` and returns the status
    /// byte and the final contents of the data buffer.
    fn run(dev: &VirtioBlk, req_type: u32, sector: u64, data: &[u8], data_writable: bool) -> Resp {
        let mut backing = vec![0u8; WIN];
        let mem = GuestMem::new(backing.as_mut_ptr(), BASE, WIN);
        mem.write_u32(BASE + O_HDR, req_type);
        mem.write_u32(BASE + O_HDR + 4, 0);
        mem.write_u64(BASE + O_HDR + 8, sector);
        mem.write_slice(BASE + O_DATA, data);
        let d = |i: u64| BASE + O_DESC + i * 16;
        mem.write_u64(d(0), BASE + O_HDR);
        mem.write_u32(d(0) + 8, 16);
        mem.write_u16(d(0) + 12, VRING_DESC_F_NEXT);
        mem.write_u16(d(0) + 14, 1);
        let data_flags = VRING_DESC_F_NEXT | if data_writable { VRING_DESC_F_WRITE } else { 0 };
        mem.write_u64(d(1), BASE + O_DATA);
        mem.write_u32(d(1) + 8, data.len() as u32);
        mem.write_u16(d(1) + 12, data_flags);
        mem.write_u16(d(1) + 14, 2);
        mem.write_u64(d(2), BASE + O_STATUS);
        mem.write_u32(d(2) + 8, 1);
        mem.write_u16(d(2) + 12, VRING_DESC_F_WRITE);
        mem.write_u16(d(2) + 14, 0);
        mem.write_u16(BASE + O_AVAIL + 4, 0);
        mem.write_u16(BASE + O_AVAIL + 2, 1);
        let mut q = Queue::new(mem.clone());
        q.size = 8;
        q.ready = true;
        q.desc_table = BASE + O_DESC;
        q.avail_ring = BASE + O_AVAIL;
        q.used_ring = BASE + O_USED;
        dev.activate(vec![q]);
        dev.notify(0);
        let mut sb = [0u8; 1];
        mem.read_slice(BASE + O_STATUS, &mut sb);
        let mut out = vec![0u8; data.len()];
        mem.read_slice(BASE + O_DATA, &mut out);
        // Detach the queue before `backing` is freed so the device does not
        // keep a GuestMem pointing at deallocated memory.
        dev.activate(Vec::new());
        Resp {
            status: sb[0],
            data: out,
        }
    }

    #[test]
    fn write_then_read_round_trips() {
        let (dev, _img) = make_rw(64 * 1024);
        let mut payload = b"SUPERMACHINE-BLK-ROUNDTRIP".to_vec();
        payload.resize(512, 0);
        let w = run(&dev, VIRTIO_BLK_T_OUT, 1, &payload, false);
        assert_eq!(w.status, VIRTIO_BLK_S_OK, "write should succeed");
        let r = run(&dev, VIRTIO_BLK_T_IN, 1, &[0u8; 512], true);
        assert_eq!(r.status, VIRTIO_BLK_S_OK, "read should succeed");
        assert_eq!(r.data, payload, "read-back must match written bytes");
    }

    #[test]
    fn huge_sector_read_is_ioerr_not_oob() {
        let (dev, _img) = make_rw(64 * 1024);
        let r = run(&dev, VIRTIO_BLK_T_IN, u64::MAX, &[0u8; 512], true);
        assert_eq!(r.status, VIRTIO_BLK_S_IOERR, "overflowing read must IOERR");
    }

    #[test]
    fn huge_sector_write_is_ioerr_not_oob() {
        let (dev, _img) = make_rw(64 * 1024);
        let r = run(&dev, VIRTIO_BLK_T_OUT, u64::MAX, &[0xABu8; 512], false);
        assert_eq!(r.status, VIRTIO_BLK_S_IOERR, "overflowing write must IOERR");
    }

    #[test]
    fn sector_past_end_is_ioerr() {
        let (dev, _img) = make_rw(64 * 1024);
        let r = run(&dev, VIRTIO_BLK_T_IN, 200, &[0u8; 512], true);
        assert_eq!(r.status, VIRTIO_BLK_S_IOERR);
    }

    #[test]
    fn last_sector_in_bounds_is_ok() {
        let (dev, _img) = make_rw(64 * 1024);
        let r = run(&dev, VIRTIO_BLK_T_IN, 127, &[0u8; 512], true);
        assert_eq!(
            r.status, VIRTIO_BLK_S_OK,
            "last full sector must be readable"
        );
        let r2 = run(&dev, VIRTIO_BLK_T_IN, 128, &[0u8; 512], true);
        assert_eq!(r2.status, VIRTIO_BLK_S_IOERR);
    }

    #[test]
    fn readonly_device_rejects_writes() {
        let (dev, _img) = make_ro(64 * 1024);
        let r = run(&dev, VIRTIO_BLK_T_OUT, 0, &[0xCDu8; 512], false);
        assert_eq!(
            r.status, VIRTIO_BLK_S_UNSUPP,
            "RO device must reject T_OUT with UNSUPP"
        );
    }

    #[test]
    fn flush_on_rw_device_is_ok() {
        let (dev, _img) = make_rw(64 * 1024);
        let r = run(&dev, VIRTIO_BLK_T_FLUSH, 0, &[0u8; 512], true);
        assert_eq!(r.status, VIRTIO_BLK_S_OK, "flush of a healthy image must succeed");
    }

    #[test]
    fn unknown_request_type_is_unsupp() {
        let (dev, _img) = make_rw(64 * 1024);
        let r = run(&dev, 0xFFFF, 0, &[0u8; 512], true);
        assert_eq!(r.status, VIRTIO_BLK_S_UNSUPP);
    }
}