use std::fs::{File, OpenOptions};
use std::os::unix::io::AsRawFd;
use std::sync::{Arc, Mutex};
use super::queue::Queue;
use super::{VirtioDevice, VIRTIO_ID_BLOCK};
// Request types read from the first (header) descriptor of each request chain.
const VIRTIO_BLK_T_IN: u32 = 0;
const VIRTIO_BLK_T_OUT: u32 = 1;
const VIRTIO_BLK_T_FLUSH: u32 = 4;
const VIRTIO_BLK_T_GET_ID: u32 = 8;
// Status codes written back to the final (status) descriptor of each chain.
const VIRTIO_BLK_S_OK: u8 = 0;
const VIRTIO_BLK_S_IOERR: u8 = 1;
const VIRTIO_BLK_S_UNSUPP: u8 = 2;
// Block-device feature bits. SIZE_MAX/SEG_MAX/BLK_SIZE are declared but not
// currently offered by features(); only RO is, for read-only images.
const VIRTIO_BLK_F_SIZE_MAX: u64 = 1 << 1;
const VIRTIO_BLK_F_SEG_MAX: u64 = 1 << 2;
const VIRTIO_BLK_F_RO: u64 = 1 << 5;
const VIRTIO_BLK_F_BLK_SIZE: u64 = 1 << 6;
// Transport-level feature: modern (virtio 1.x) device.
const VIRTIO_F_VERSION_1: u64 = 1 << 32;
// Fixed sector size in bytes; capacity and request offsets are in these units.
const SECTOR_SIZE: u64 = 512;
/// A virtio block device backed by a memory-mapped file image.
pub struct VirtioBlk {
    // Device name; also reported to the guest via GET_ID requests.
    name: String,
    // Start of the mmap'd backing image (established in open_ro / open_rw).
    backing_ptr: *mut u8,
    // Length of the mapping in bytes.
    backing_len: usize,
    // When false (open_ro): OUT requests get VIRTIO_BLK_S_UNSUPP and FLUSH
    // is a no-op.
    writable: bool,
    // Virtqueues installed by activate(); only queue 0 is serviced.
    queues: Mutex<Vec<Queue>>,
    // Set once activate() has run.
    activated: std::sync::atomic::AtomicBool,
    // Callback used to raise the device interrupt after completing requests.
    irq_raise: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
}
// SAFETY: the raw `backing_ptr` is the only field blocking the auto-traits.
// It points at a file mapping owned by this struct for its whole lifetime, and
// all reads/writes through it after construction happen inside drain_q while
// the `queues` mutex is held, so sharing the device across threads is sound.
unsafe impl Send for VirtioBlk {}
unsafe impl Sync for VirtioBlk {}
impl VirtioBlk {
    /// Open `path` read-only and map it privately. The resulting device
    /// advertises VIRTIO_BLK_F_RO and rejects write requests.
    ///
    /// # Errors
    /// Returns the OS error if the file cannot be opened, stat'ed, or mapped
    /// (an empty file fails at mmap, which rejects zero-length mappings).
    pub fn open_ro(name: &str, path: &str) -> std::io::Result<Self> {
        let f = File::open(path)?;
        let len = f.metadata()?.len() as usize;
        // SAFETY: mapping a valid fd with a length taken from its metadata;
        // the result is checked against MAP_FAILED below. The mapping outlives
        // `f` — closing the fd does not tear down an mmap.
        let p = unsafe {
            libc::mmap(
                std::ptr::null_mut(),
                len,
                libc::PROT_READ,
                libc::MAP_PRIVATE,
                f.as_raw_fd(),
                0,
            )
        };
        if p == libc::MAP_FAILED {
            return Err(std::io::Error::last_os_error());
        }
        // Best-effort readahead hint; failure is harmless, result ignored.
        unsafe {
            libc::madvise(p, len, libc::MADV_SEQUENTIAL);
        }
        eprintln!("[virtio-blk:{name}] mmap ro {} bytes from {path}", len);
        Ok(Self {
            name: name.to_string(),
            backing_ptr: p as *mut u8,
            backing_len: len,
            writable: false,
            queues: Mutex::new(Vec::new()),
            activated: std::sync::atomic::AtomicBool::new(false),
            irq_raise: Mutex::new(None),
        })
    }

    /// Open (or create) `path` read-write, growing it to at least
    /// `size_bytes`, and map it shared so guest writes reach the file.
    /// An existing file larger than `size_bytes` is never truncated.
    ///
    /// # Errors
    /// Returns the OS error if the file cannot be opened, resized, or mapped.
    pub fn open_rw(name: &str, path: &str, size_bytes: u64) -> std::io::Result<Self> {
        let f = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(path)?;
        let cur_len = f.metadata()?.len();
        if cur_len < size_bytes {
            f.set_len(size_bytes)?;
        }
        let len = f.metadata()?.len() as usize;
        // SAFETY: as in open_ro; MAP_SHARED so stores are written back.
        let p = unsafe {
            libc::mmap(
                std::ptr::null_mut(),
                len,
                libc::PROT_READ | libc::PROT_WRITE,
                libc::MAP_SHARED,
                f.as_raw_fd(),
                0,
            )
        };
        if p == libc::MAP_FAILED {
            return Err(std::io::Error::last_os_error());
        }
        // Best-effort access-pattern hint; failure is harmless.
        unsafe {
            libc::madvise(p, len, libc::MADV_RANDOM);
        }
        eprintln!("[virtio-blk:{name}] mmap rw {} bytes from {path}", len);
        Ok(Self {
            name: name.to_string(),
            backing_ptr: p as *mut u8,
            backing_len: len,
            writable: true,
            queues: Mutex::new(Vec::new()),
            activated: std::sync::atomic::AtomicBool::new(false),
            irq_raise: Mutex::new(None),
        })
    }

    /// Install the callback used to raise the guest-visible interrupt after
    /// requests complete.
    pub fn set_irq_raise(&self, f: Arc<dyn Fn() + Send + Sync>) {
        *self.irq_raise.lock().unwrap() = Some(f);
    }

    /// Convert a 512-byte sector index into a byte offset, failing instead of
    /// wrapping on overflow (a guest-controlled value must not be able to
    /// panic a debug build or wrap the bounds checks in release).
    fn sector_byte_offset(sector: u64) -> Option<usize> {
        sector
            .checked_mul(SECTOR_SIZE)
            .and_then(|b| usize::try_from(b).ok())
    }

    /// Process all pending requests on queue 0, writing a status byte and a
    /// used-ring entry for each, then raise the interrupt once if anything
    /// was consumed.
    fn drain_q(&self) {
        let mut qs = self.queues.lock().unwrap();
        let q = match qs.get_mut(0) {
            Some(q) => q,
            None => return,
        };
        if !q.ready {
            return;
        }
        let mut any_used = false;
        while let Some((head, chain)) = q.pop_chain() {
            // Every request needs at least a header and a status descriptor;
            // complete malformed chains with length 0 (no status byte to set).
            if chain.len() < 2 {
                q.add_used(head, 0);
                any_used = true;
                continue;
            }
            let hdr = chain[0];
            let status_desc = chain[chain.len() - 1];
            // Request header layout: type (u32), reserved (u32), sector (u64).
            let req_type = q.mem.read_u32(hdr.addr);
            let _reserved = q.mem.read_u32(hdr.addr + 4);
            let sector = q.mem.read_u64(hdr.addr + 8);
            let mut status = VIRTIO_BLK_S_OK;
            // Used-ring length counts bytes written to device-writable
            // buffers; it starts at 1 for the status byte written below.
            let mut bytes_written: u32 = 1;
            match req_type {
                VIRTIO_BLK_T_IN => match Self::sector_byte_offset(sector) {
                    Some(mut off) => {
                        for d in &chain[1..chain.len() - 1] {
                            let want = d.len as usize;
                            // checked_add so a huge descriptor length cannot
                            // wrap the bound test.
                            let in_bounds = off
                                .checked_add(want)
                                .map_or(false, |end| end <= self.backing_len);
                            if !in_bounds {
                                status = VIRTIO_BLK_S_IOERR;
                                break;
                            }
                            // SAFETY: off + want <= backing_len was verified
                            // just above, so the source range is in-mapping.
                            unsafe {
                                let src = self.backing_ptr.add(off);
                                let slice = std::slice::from_raw_parts(src, want);
                                q.mem.write_slice(d.addr, slice);
                            }
                            bytes_written += want as u32;
                            off += want;
                        }
                    }
                    None => status = VIRTIO_BLK_S_IOERR,
                },
                VIRTIO_BLK_T_FLUSH => {
                    // Flush is a no-op for read-only images. For writable
                    // images a failed msync is now reported as an I/O error
                    // instead of silently claiming success (which would mask
                    // data loss from the guest).
                    if self.writable {
                        // SAFETY: syncing exactly the region mapped in open_rw.
                        let rc = unsafe {
                            libc::msync(
                                self.backing_ptr as *mut libc::c_void,
                                self.backing_len,
                                libc::MS_SYNC,
                            )
                        };
                        if rc != 0 {
                            status = VIRTIO_BLK_S_IOERR;
                        }
                    }
                }
                VIRTIO_BLK_T_GET_ID => {
                    // 20-byte device ID string, left-padded with spaces.
                    // NOTE(review): many implementations NUL-pad instead —
                    // confirm guests tolerate space padding.
                    let id = format!("{:>20}", self.name);
                    let bytes = id.as_bytes();
                    if let Some(d) = chain.get(1) {
                        let take = (d.len as usize).min(bytes.len());
                        q.mem.write_slice(d.addr, &bytes[..take]);
                        bytes_written += take as u32;
                    }
                }
                VIRTIO_BLK_T_OUT => {
                    if !self.writable {
                        status = VIRTIO_BLK_S_UNSUPP;
                    } else {
                        match Self::sector_byte_offset(sector) {
                            Some(mut off) => {
                                for d in &chain[1..chain.len() - 1] {
                                    let n = d.len as usize;
                                    let in_bounds = off
                                        .checked_add(n)
                                        .map_or(false, |end| end <= self.backing_len);
                                    if !in_bounds {
                                        status = VIRTIO_BLK_S_IOERR;
                                        break;
                                    }
                                    // Bounce through a temp buffer; guest
                                    // memory is only reachable via q.mem.
                                    let mut tmp = vec![0u8; n];
                                    q.mem.read_slice(d.addr, &mut tmp);
                                    // SAFETY: off + n <= backing_len was
                                    // verified just above.
                                    unsafe {
                                        let dst = self.backing_ptr.add(off);
                                        std::ptr::copy_nonoverlapping(tmp.as_ptr(), dst, n);
                                    }
                                    off += n;
                                }
                            }
                            None => status = VIRTIO_BLK_S_IOERR,
                        }
                    }
                }
                _ => {
                    status = VIRTIO_BLK_S_UNSUPP;
                }
            }
            q.mem.write_slice(status_desc.addr, &[status]);
            q.add_used(head, bytes_written);
            any_used = true;
        }
        if any_used {
            // Clone the callback out, then release the queue lock before
            // raising the interrupt so the handler cannot deadlock by
            // re-entering notify()/drain_q.
            let f_opt = self.irq_raise.lock().unwrap().clone();
            drop(qs);
            if let Some(f) = f_opt {
                f();
            }
        }
    }
}

impl Drop for VirtioBlk {
    /// Tear down the backing mapping created in open_ro / open_rw; the
    /// original code leaked it for the life of the process.
    fn drop(&mut self) {
        // SAFETY: backing_ptr/backing_len describe exactly the region mmap
        // returned, and no further access happens after drop.
        unsafe {
            libc::munmap(self.backing_ptr as *mut libc::c_void, self.backing_len);
        }
    }
}
impl VirtioDevice for VirtioBlk {
    /// Virtio device class identifier: block device.
    fn device_id(&self) -> u32 {
        VIRTIO_ID_BLOCK
    }

    /// This device exposes a single request queue.
    fn num_queues(&self) -> usize {
        1
    }

    /// Device config space: the 64-bit capacity in 512-byte sectors,
    /// little-endian.
    fn config(&self) -> Vec<u8> {
        let capacity_sectors = self.backing_len as u64 / SECTOR_SIZE;
        Vec::from(capacity_sectors.to_le_bytes())
    }

    /// Always offers VIRTIO_F_VERSION_1; adds VIRTIO_BLK_F_RO when the
    /// backing image was opened read-only.
    fn features(&self) -> u64 {
        if self.writable {
            VIRTIO_F_VERSION_1
        } else {
            VIRTIO_F_VERSION_1 | VIRTIO_BLK_F_RO
        }
    }

    /// Queue notification from the guest: service queue 0 (the only queue).
    fn notify(&self, _q: u16) {
        self.drain_q();
    }

    /// Install the negotiated queues and mark the device live.
    fn activate(&self, queues: Vec<Queue>) {
        use std::sync::atomic::Ordering;
        *self.queues.lock().unwrap() = queues;
        self.activated.store(true, Ordering::Release);
        eprintln!(
            "[virtio-blk:{}] activated, {} sectors",
            self.name,
            self.backing_len as u64 / SECTOR_SIZE
        );
    }

    /// Clone out the current queue state (e.g. for save/restore).
    fn snapshot_queues(&self) -> Vec<Queue> {
        self.queues.lock().unwrap().clone()
    }
}