#![cfg_attr(feature = "_unsafe-op-in-unsafe-fn", deny(unsafe_op_in_unsafe_fn))]
#![cfg_attr(not(feature = "_unsafe-op-in-unsafe-fn"), allow(unused_unsafe))]
mod drivers;
mod properties;
mod wait;
use crate::properties::Properties;
use crate::wait::TimeoutUpdater;
use bitflags::bitflags;
use const_cstr::{const_cstr, ConstCStr};
use libc::{c_char, c_void, iovec, off_t, sigset_t};
use nix::errno::Errno;
use nix::sys::memfd::{memfd_create, MemFdCreateFlag};
use nix::sys::mman::{mmap, munmap, MapFlags, ProtFlags};
use nix::unistd::{close, ftruncate};
use std::borrow::Cow;
use std::collections::{HashSet, VecDeque};
use std::ffi::CStr;
use std::fmt;
use std::fs::File;
use std::io;
use std::os::unix::io::RawFd;
use std::os::unix::io::{FromRawFd, IntoRawFd};
use std::result;
use std::time::Duration;
use std::{error, ptr};
bitflags! {
/// Per-request flag bits passed to the `Blkioq` request-submission methods.
///
/// Each request type accepts only a subset of these bits; `validate_req_flags`
/// fails the request with `EINVAL` when a bit outside that subset is set.
#[repr(transparent)]
pub struct ReqFlags: u32 {
/// Bit 0. In this file it is accepted only by `Blkioq::write` and `Blkioq::writev`.
const FUA = 1 << 0;
/// Bit 1. In this file it is accepted only by `Blkioq::write_zeroes`.
const NO_UNMAP = 1 << 1;
/// Bit 2. In this file it is accepted only by `Blkioq::write_zeroes`.
const NO_FALLBACK = 1 << 2;
}
}
fn validate_req_flags(req: &Request, allowed: ReqFlags) -> Option<Completion> {
if allowed.contains(req.flags) {
None
} else if !ReqFlags::all().contains(req.flags) {
Some(Completion::for_failed_req(
req,
Errno::EINVAL,
const_cstr!("unsupported bits in request flags"),
))
} else {
let first_disallowed_flag = 1 << (req.flags & !allowed).bits().trailing_zeros();
let first_disallowed_flag = ReqFlags::from_bits(first_disallowed_flag).unwrap();
let error_msg = match first_disallowed_flag {
ReqFlags::FUA => const_cstr!("BLKIO_REQ_FUA is invalid for this request type"),
ReqFlags::NO_UNMAP => {
const_cstr!("BLKIO_REQ_NO_UNMAP is invalid for this request type")
}
ReqFlags::NO_FALLBACK => {
const_cstr!("BLKIO_REQ_NO_FALLBACK is invalid for this request type")
}
_ => panic!(),
};
Some(Completion::for_failed_req(req, Errno::EINVAL, error_msg))
}
}
/// Error type returned by this crate's fallible operations.
///
/// Pairs a numeric errno value with a human-readable message.
#[derive(Debug)]
pub struct Error {
// Errno value stored as a plain integer (see `Error::errno`).
errno: i32,
// Message text; either a static string or an owned, formatted one.
message: Cow<'static, str>,
}
impl Error {
pub fn new<M>(errno: Errno, message: M) -> Self
where
Cow<'static, str>: From<M>,
{
Self {
errno: errno as i32,
message: message.into(),
}
}
pub fn from_io_error(io_error: io::Error, default_errno: Errno) -> Self {
Self {
errno: io_error.raw_os_error().unwrap_or(default_errno as i32),
message: io_error.to_string().into(),
}
}
pub fn from_last_os_error() -> Self {
let io_error = io::Error::last_os_error();
Self {
errno: io_error.raw_os_error().unwrap(),
message: io_error.to_string().into(),
}
}
pub fn errno(&self) -> i32 {
self.errno
}
pub fn message(&self) -> &str {
&self.message
}
}
// Empty impl: only the trait's provided methods are used; the required `Debug` and
// `Display` supertraits come from the derive on the struct and the `Display` impl below.
impl error::Error for Error {}
impl From<Errno> for Error {
fn from(nix_error: Errno) -> Self {
Self {
errno: nix_error as i32,
message: nix_error.desc().into(),
}
}
}
/// Displays only the message text; the errno value is not included.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.message)
    }
}
/// Shorthand for `std::result::Result` with this crate's [`Error`] as the error type.
pub type Result<T> = result::Result<T, Error>;
/// Outcome of a single request, laid out with C representation (`#[repr(C)]`).
#[repr(C)]
pub struct Completion {
/// The `user_data` value copied from the originating request.
pub user_data: usize,
/// Null for successful requests; for failed requests, a pointer to a constant
/// C string describing the failure.
pub error_msg: *const c_char,
/// `0` on success; the negated errno value when built by `for_failed_req`.
pub ret: i32,
/// Reserved bytes; the constructors in this file always zero them.
pub reserved_: [u8; 12],
}
impl Completion {
pub(crate) fn for_successful_req(req: &Request) -> Self {
Self {
user_data: req.user_data,
ret: 0,
error_msg: ptr::null(),
reserved_: [0; 12],
}
}
pub(crate) fn for_failed_req(req: &Request, errno: Errno, error_msg: ConstCStr) -> Self {
Self {
user_data: req.user_data,
ret: -(errno as i32),
error_msg: error_msg.as_ptr(),
reserved_: [0; 12],
}
}
}
/// An array of `iovec` entries, either borrowed through a raw pointer or owned.
#[derive(Clone)]
pub(crate) enum IoVecArray {
// Raw pointer plus element count; nothing in this type checks that the pointer is
// valid, so every access goes through `unsafe` code.
RawBorrowed { iovec: *const iovec, iovcnt: u32 },
// Heap-allocated array owned by this value.
Owned { iovec: Box<[iovec]> },
}
/// Views an [`IoVecArray`] as a slice of `iovec` entries.
///
/// A `RawBorrowed` array with `iovcnt == 0` yields an empty slice without touching the
/// pointer. `slice::from_raw_parts` requires a non-null, aligned pointer even for
/// zero-length slices, and a caller may legitimately pass a null pointer together with a
/// zero count.
///
/// # Safety
///
/// For a `RawBorrowed` array with a non-zero `iovcnt`, the pointer must be non-null,
/// properly aligned, and valid for reads of `iovcnt` consecutive `iovec` values for as
/// long as the returned slice is in use.
unsafe fn iovecarray_as_slice(iov: &IoVecArray) -> &[iovec] {
    match iov {
        // Never build a slice from a possibly-null pointer when there are no elements.
        IoVecArray::RawBorrowed { iovcnt: 0, .. } => &[],
        IoVecArray::RawBorrowed { iovec, iovcnt } => unsafe {
            // SAFETY: the caller guarantees the pointer/count pair is valid (see above).
            std::slice::from_raw_parts(*iovec, *iovcnt as usize)
        },
        IoVecArray::Owned { iovec } => iovec.as_ref(),
    }
}
impl IoVecArray {
    /// Wraps a raw `iovec` pointer and element count without copying or validating them.
    fn from_raw_parts(iovec: *const iovec, iovcnt: u32) -> Self {
        IoVecArray::RawBorrowed { iovec, iovcnt }
    }

    /// Builds an owned single-element array describing the buffer `buf[..len]`.
    fn from_buffer(buf: *const u8, len: usize) -> Self {
        let only_entry = iovec {
            iov_base: buf as *mut c_void,
            iov_len: len,
        };
        IoVecArray::Owned {
            iovec: Box::new([only_entry]),
        }
    }

    /// Returns a pointer to the first `iovec` entry.
    fn as_ptr(&self) -> *const iovec {
        match *self {
            Self::RawBorrowed { iovec, .. } => iovec,
            Self::Owned { ref iovec } => iovec.as_ptr(),
        }
    }

    /// Returns the number of `iovec` entries.
    fn len(&self) -> u32 {
        match *self {
            Self::RawBorrowed { iovcnt, .. } => iovcnt,
            Self::Owned { ref iovec } => iovec.len() as u32,
        }
    }

    /// Returns the sum of `iov_len` over all entries.
    ///
    /// # Safety
    ///
    /// Same requirements as `iovecarray_as_slice`: a raw-borrowed array must point at
    /// readable `iovec` entries.
    unsafe fn buffer_size(&self) -> usize {
        let entries = unsafe { iovecarray_as_slice(self) };
        entries.iter().map(|entry| entry.iov_len).sum()
    }

    /// Returns a new owned array describing the same memory with the first `count` bytes
    /// skipped.
    ///
    /// # Panics
    ///
    /// Panics if `count` is not strictly less than the total buffer size.
    ///
    /// # Safety
    ///
    /// Same requirements as `iovecarray_as_slice`; additionally the entry that contains
    /// byte `count` must describe memory for which advancing `iov_base` is valid.
    unsafe fn offset(&self, count: usize) -> Self {
        let entries = unsafe { iovecarray_as_slice(self) };

        // Find the entry containing byte `count` and the position of that byte within it.
        let mut bytes_seen: usize = 0;
        let mut located = None;
        for (idx, entry) in entries.iter().enumerate() {
            bytes_seen += entry.iov_len;
            if count < bytes_seen {
                located = Some((idx, entry.iov_len - (bytes_seen - count)));
                break;
            }
        }
        let (first_idx, skip) = located.expect("the offset should be less than buffer size");

        // Copy the tail of the array and trim the skipped prefix off its first entry.
        let mut tail = entries[first_idx..].to_vec();
        let head = &mut tail[0];
        head.iov_base = unsafe { head.iov_base.add(skip) };
        head.iov_len -= skip;

        IoVecArray::Owned {
            iovec: tail.into_boxed_slice(),
        }
    }

    /// Overwrites every byte described by the array with zero.
    ///
    /// # Safety
    ///
    /// Same requirements as `iovecarray_as_slice`; additionally every entry must describe
    /// memory that is valid for writes of `iov_len` bytes.
    unsafe fn fill_with_zeroes(&mut self) {
        let entries = unsafe { iovecarray_as_slice(self) };
        for entry in entries.iter() {
            let start = entry.iov_base.cast::<u8>();
            unsafe {
                ptr::write_bytes(start, 0, entry.iov_len);
            }
        }
    }
}
/// The operation a [`Request`] performs, together with its operation-specific arguments.
///
/// Each variant is produced by the `Blkioq` method of the corresponding name
/// (`read`, `write`, `readv`, `writev`, `write_zeroes`, `discard`, `flush`).
#[derive(Clone)]
pub(crate) enum RequestTypeArgs {
// Single-buffer read: `start`, destination pointer and byte length as passed to
// `Blkioq::read`.
Read {
start: u64,
buf: *mut u8,
len: usize,
},
// Single-buffer write: `start`, source pointer and byte length as passed to
// `Blkioq::write`.
Write {
start: u64,
buf: *const u8,
len: usize,
},
// Vectored read using an `iovec` array.
Readv {
start: u64,
iovec: IoVecArray,
},
// Vectored write using an `iovec` array.
Writev {
start: u64,
iovec: IoVecArray,
},
// Write-zeroes request over the range given by `start` and `len`.
WriteZeroes {
start: u64,
len: u64,
},
// Discard request over the range given by `start` and `len`.
Discard {
start: u64,
len: u64,
},
// Flush request; carries no arguments.
Flush,
}
/// A request as submitted through a `Blkioq`.
#[derive(Clone)]
pub(crate) struct Request {
// Operation and its arguments.
pub(crate) args: RequestTypeArgs,
// Caller-chosen value copied verbatim into the request's `Completion`.
pub(crate) user_data: usize,
// Flags supplied by the caller; checked by `validate_req_flags` before enqueueing.
pub(crate) flags: ReqFlags,
}
/// Driver-side queue interface that a `Blkioq` wraps.
pub(crate) trait Queue {
// Whether this is a poll queue. `Blkioq::do_io` refuses a signal mask on poll queues.
fn is_poll_queue(&self) -> bool;
// File descriptor associated with completions, if the queue has one.
fn get_completion_fd(&self) -> Option<RawFd>;
// Enables or disables the completion fd.
// NOTE(review): exact semantics are defined by the implementations — confirm there.
fn set_completion_fd_enabled(&mut self, enabled: bool);
// Attempts to enqueue `req`. On `Err` the request is handed back to the caller, which
// `RequestBacklog` then keeps for a later retry. Implementations receive the completion
// backlog as well — presumably to record immediate completions; verify against drivers.
fn try_enqueue(
&mut self,
completion_backlog: &mut CompletionBacklog,
req: Request,
) -> result::Result<(), Request>;
// Performs I/O on the queue and writes completions into `completions`, returning a
// count on success.
// NOTE(review): the precise meaning of `min_completions`, `timeout_updater` and `sig`
// is defined by the implementations, which are not visible in this file.
fn do_io(
&mut self,
request_backlog: &mut RequestBacklog,
completion_backlog: &mut CompletionBacklog,
completions: &mut [std::mem::MaybeUninit<Completion>],
min_completions: usize,
timeout_updater: Option<&mut TimeoutUpdater>,
sig: Option<&sigset_t>,
) -> Result<usize>;
}
/// FIFO of requests that a `Queue` did not accept when they were submitted.
pub(crate) struct RequestBacklog {
// Oldest request at the front.
reqs: VecDeque<Request>,
}
impl RequestBacklog {
    /// Creates an empty backlog.
    fn new() -> RequestBacklog {
        let reqs = VecDeque::new();
        RequestBacklog { reqs }
    }

    /// Number of requests currently waiting in the backlog.
    pub(crate) fn len(&self) -> usize {
        self.reqs.len()
    }

    /// Submits `req` to `queue`, or appends it to the backlog.
    ///
    /// To keep submission order, the queue is only tried when nothing is already
    /// backlogged; if the queue hands the request back it is appended to the backlog.
    fn enqueue_or_backlog(
        &mut self,
        queue: &mut dyn Queue,
        completion_backlog: &mut CompletionBacklog,
        req: Request,
    ) {
        let pending = if self.reqs.is_empty() {
            match queue.try_enqueue(completion_backlog, req) {
                Ok(()) => return,
                Err(returned) => returned,
            }
        } else {
            req
        };
        self.reqs.push_back(pending);
    }

    /// Moves backlogged requests into `queue`, oldest first, until the queue hands one
    /// back. Returns how many requests were accepted.
    pub(crate) fn process(
        &mut self,
        queue: &mut dyn Queue,
        completion_backlog: &mut CompletionBacklog,
    ) -> usize {
        let mut accepted = 0;
        loop {
            let req = match self.reqs.pop_front() {
                Some(req) => req,
                None => break,
            };
            match queue.try_enqueue(completion_backlog, req) {
                Ok(()) => accepted += 1,
                Err(returned) => {
                    // Put it back at the head so ordering is preserved for the next try.
                    self.reqs.push_front(returned);
                    break;
                }
            }
        }
        accepted
    }
}
/// FIFO of completions that have been produced but not yet handed to the caller.
pub(crate) struct CompletionBacklog {
// Oldest completion at the front.
completions: VecDeque<Completion>,
// When present, a `u64` value of 1 is written to this fd whenever completions are
// added (see `signal_completion_fd`).
completion_fd: Option<RawFd>,
}
impl CompletionBacklog {
fn new(completion_fd: Option<RawFd>) -> Self {
Self {
completions: VecDeque::new(),
completion_fd,
}
}
pub(crate) fn len(&self) -> usize {
self.completions.len()
}
fn signal_completion_fd(&mut self) {
if let Some(fd) = self.completion_fd {
let val: u64 = 1;
let valp: *const u64 = &val;
unsafe { libc::write(fd, valp.cast(), std::mem::size_of::<u64>()) };
}
}
pub(crate) fn push(&mut self, completion: Completion) {
self.completions.push_back(completion);
self.signal_completion_fd();
}
pub(crate) fn fill_completions(
&mut self,
completions: &mut [std::mem::MaybeUninit<Completion>],
) -> usize {
let mut n = 0;
for c in completions.iter_mut().take(self.completions.len()) {
let val = self.completions.pop_front().unwrap();
unsafe { c.as_mut_ptr().write(val) };
n += 1;
}
n
}
pub(crate) fn unfill_completions(
&mut self,
completions: &mut [std::mem::MaybeUninit<Completion>],
count: usize,
) {
for c in completions[..count].iter().rev() {
self.completions.push_front(unsafe { c.as_ptr().read() });
}
self.signal_completion_fd();
}
}
/// A request queue: wraps a driver-provided `Queue` together with the backlogs used when
/// requests cannot be enqueued immediately or completions are produced locally.
pub struct Blkioq {
// Driver-specific queue implementation.
queue: Box<dyn Queue>,
// Requests the queue has not accepted yet.
request_backlog: RequestBacklog,
// Completions waiting to be returned to the caller.
completion_backlog: CompletionBacklog,
}
impl Blkioq {
pub(crate) fn new(queue: Box<dyn Queue>) -> Self {
let completion_fd = queue.get_completion_fd();
Blkioq {
queue,
request_backlog: RequestBacklog::new(),
completion_backlog: CompletionBacklog::new(completion_fd),
}
}
pub fn get_completion_fd(&self) -> Option<RawFd> {
self.queue.get_completion_fd()
}
pub fn set_completion_fd_enabled(&mut self, enabled: bool) {
self.queue.set_completion_fd_enabled(enabled);
}
fn enqueue(&mut self, allowed_flags: ReqFlags, req: Request) {
if let Some(completion) = validate_req_flags(&req, allowed_flags) {
self.completion_backlog.push(completion);
return;
}
self.request_backlog
.enqueue_or_backlog(&mut *self.queue, &mut self.completion_backlog, req)
}
pub fn read(
&mut self,
start: u64,
buf: *mut u8,
len: usize,
user_data: usize,
flags: ReqFlags,
) {
self.enqueue(
ReqFlags::empty(),
Request {
args: RequestTypeArgs::Read { start, buf, len },
user_data,
flags,
},
)
}
pub fn write(
&mut self,
start: u64,
buf: *const u8,
len: usize,
user_data: usize,
flags: ReqFlags,
) {
self.enqueue(
ReqFlags::FUA,
Request {
args: RequestTypeArgs::Write { start, buf, len },
user_data,
flags,
},
)
}
pub fn readv(
&mut self,
start: u64,
iovec: *const iovec,
iovcnt: u32,
user_data: usize,
flags: ReqFlags,
) {
let req = Request {
args: RequestTypeArgs::Readv {
start,
iovec: IoVecArray::from_raw_parts(iovec, iovcnt),
},
user_data,
flags,
};
if iovcnt > i32::MAX as u32 {
self.completion_backlog.push(Completion::for_failed_req(
&req,
Errno::EINVAL,
const_cstr!("iovcnt must be non-negative and fit in a signed 32-bit integer"),
));
return;
}
self.enqueue(ReqFlags::empty(), req)
}
pub fn writev(
&mut self,
start: u64,
iovec: *const iovec,
iovcnt: u32,
user_data: usize,
flags: ReqFlags,
) {
let req = Request {
args: RequestTypeArgs::Writev {
start,
iovec: IoVecArray::from_raw_parts(iovec, iovcnt),
},
user_data,
flags,
};
if iovcnt > i32::MAX as u32 {
self.completion_backlog.push(Completion::for_failed_req(
&req,
Errno::EINVAL,
const_cstr!("iovcnt must be non-negative and fit in a signed 32-bit integer"),
));
return;
}
self.enqueue(ReqFlags::FUA, req)
}
pub fn write_zeroes(&mut self, start: u64, len: u64, user_data: usize, flags: ReqFlags) {
self.enqueue(
ReqFlags::NO_UNMAP | ReqFlags::NO_FALLBACK,
Request {
args: RequestTypeArgs::WriteZeroes { start, len },
user_data,
flags,
},
)
}
pub fn discard(&mut self, start: u64, len: u64, user_data: usize, flags: ReqFlags) {
self.enqueue(
ReqFlags::empty(),
Request {
args: RequestTypeArgs::Discard { start, len },
user_data,
flags,
},
)
}
pub fn flush(&mut self, user_data: usize, flags: ReqFlags) {
self.enqueue(
ReqFlags::empty(),
Request {
args: RequestTypeArgs::Flush,
user_data,
flags,
},
);
}
pub fn do_io(
&mut self,
completions: &mut [std::mem::MaybeUninit<Completion>],
min_completions: usize,
timeout: Option<&mut Duration>,
sig: Option<&sigset_t>,
) -> Result<usize> {
if sig.is_some() && self.queue.is_poll_queue() {
return Err(Error::new(
Errno::ENOTSUP,
"blkioq_do_io_interruptible() is not supported on poll queues",
));
}
let mut timeout_updater = timeout.as_deref().map(|t| TimeoutUpdater::new(*t));
let result = self.queue.do_io(
&mut self.request_backlog,
&mut self.completion_backlog,
completions,
min_completions,
timeout_updater.as_mut(),
sig,
);
if let Some(timeout) = timeout {
*timeout = timeout_updater.unwrap().next();
};
result
}
}
// Lifecycle state of a driver. The derived `PartialOrd` follows declaration order
// (`Created < Connected < Started`), which is what lets code test
// `state() < State::Connected`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
enum State {
Created, Connected, Started, }
/// Describes a region of memory.
///
/// Regions are compared and hashed by all fields, which is how `Blkio` tracks allocated
/// and mapped regions in its `HashSet`s.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub struct MemoryRegion {
/// Start address of the region in this process (the `mmap` result for regions created
/// by `alloc_mem_region`).
pub addr: usize,
/// Set to 0 by `alloc_mem_region`. NOTE(review): presumably an I/O virtual address —
/// confirm against the drivers.
pub iova: u64,
/// Length of the region in bytes.
pub len: usize,
/// File descriptor associated with the region (the memfd for regions created by
/// `alloc_mem_region`).
pub fd: RawFd,
/// Set to 0 by `alloc_mem_region`. NOTE(review): presumably the offset of the region
/// within `fd` — confirm against the drivers.
pub fd_offset: i64,
/// Set to 0 by `alloc_mem_region`; meaning of individual bits is not defined in this file.
pub flags: u32,
}
trait Driver: Properties {
fn state(&self) -> State;
fn connect(&mut self) -> Result<()>;
fn start(&mut self) -> Result<()>;
fn alloc_mem_region(&mut self, len: usize) -> Result<MemoryRegion> {
if self.state() < State::Connected {
return Err(properties::error_must_be_connected());
}
let align = self.get_u64("mem-region-alignment")? as usize;
if len % align != 0 {
return Err(Error::new(
Errno::EINVAL,
format!("len {} violates mem-region-alignment {}", len, align),
));
}
let fd = memfd_create(
CStr::from_bytes_with_nul(b"libblkio-buf\0").unwrap(),
MemFdCreateFlag::MFD_CLOEXEC,
)?;
let file = unsafe { File::from_raw_fd(fd) };
ftruncate(fd, len as off_t)?;
let addr = unsafe {
mmap(
ptr::null_mut(),
len,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
fd,
0,
)?
};
if (addr as usize) % align != 0 {
unsafe { munmap(addr, len)? };
return Err(Error::new(
Errno::EOVERFLOW,
format!(
"Address {} violates mem-region-alignment {}",
addr as usize, align,
),
));
}
Ok(MemoryRegion {
addr: addr as usize,
iova: 0,
len,
fd: file.into_raw_fd(),
fd_offset: 0,
flags: 0,
})
}
fn free_mem_region(&mut self, region: &MemoryRegion) {
let _ = unsafe { munmap(region.addr as *mut c_void, region.len) };
let _ = close(region.fd);
}
fn map_mem_region(&mut self, region: &MemoryRegion) -> Result<()>;
fn unmap_mem_region(&mut self, region: &MemoryRegion);
fn get_queue(&mut self, index: usize) -> Result<&mut Blkioq>;
fn get_poll_queue(&mut self, index: usize) -> Result<&mut Blkioq>;
}
/// A block I/O instance backed by one driver, tracking the memory regions it has
/// allocated and mapped so they can be released on drop.
pub struct Blkio {
// Backend selected by name in `Blkio::new`.
driver: Box<dyn Driver>,
// Regions returned by `alloc_mem_region` and not yet freed.
allocated_regions: HashSet<MemoryRegion>,
// Regions successfully mapped with `map_mem_region` and not yet unmapped.
mapped_regions: HashSet<MemoryRegion>, }
impl Blkio {
pub fn new(driver_name: &str) -> Result<Blkio> {
let driver: Box<dyn Driver> = match driver_name {
#[cfg(feature = "io_uring")]
"io_uring" => Box::new(drivers::iouring::IoUring::new()),
#[cfg(feature = "nvme-io_uring")]
"nvme-io_uring" => Box::new(drivers::nvme_io_uring::NvmeIoUring::new()),
#[cfg(feature = "virtio-blk-vfio-pci")]
drivers::virtio_blk::VFIO_PCI_DRIVER => {
Box::new(drivers::virtio_blk::VirtioBlk::new(driver_name))
}
#[cfg(feature = "virtio-blk-vhost-user")]
drivers::virtio_blk::VHOST_USER_DRIVER => {
Box::new(drivers::virtio_blk::VirtioBlk::new(driver_name))
}
#[cfg(feature = "virtio-blk-vhost-vdpa")]
drivers::virtio_blk::VHOST_VDPA_DRIVER => {
Box::new(drivers::virtio_blk::VirtioBlk::new(driver_name))
}
_ => return Err(Error::new(Errno::ENOENT, "Unknown driver name")),
};
Ok(Blkio {
driver,
allocated_regions: HashSet::new(),
mapped_regions: HashSet::new(),
})
}
pub fn connect(&mut self) -> Result<()> {
self.driver.connect()
}
pub fn start(&mut self) -> Result<()> {
self.driver.start()
}
pub fn get_bool(&self, name: &str) -> Result<bool> {
self.driver.get_bool(name)
}
pub fn get_i32(&self, name: &str) -> Result<i32> {
self.driver.get_i32(name)
}
pub fn get_str(&self, name: &str) -> Result<String> {
self.driver.get_str(name)
}
pub fn get_u64(&self, name: &str) -> Result<u64> {
self.driver.get_u64(name)
}
pub fn set_bool(&mut self, name: &str, value: bool) -> Result<()> {
self.driver.set_bool(name, value)
}
pub fn set_i32(&mut self, name: &str, value: i32) -> Result<()> {
self.driver.set_i32(name, value)
}
pub fn set_str(&mut self, name: &str, value: &str) -> Result<()> {
self.driver.set_str(name, value)
}
pub fn set_u64(&mut self, name: &str, value: u64) -> Result<()> {
self.driver.set_u64(name, value)
}
pub fn alloc_mem_region(&mut self, len: usize) -> Result<MemoryRegion> {
let region = self.driver.alloc_mem_region(len)?;
assert!(self.allocated_regions.insert(region));
Ok(region)
}
pub fn free_mem_region(&mut self, region: &MemoryRegion) {
assert!(!self.mapped_regions.contains(region));
assert!(self.allocated_regions.remove(region));
self.driver.free_mem_region(region);
}
pub fn map_mem_region(&mut self, region: &MemoryRegion) -> Result<()> {
let align = self.get_u64("mem-region-alignment")? as usize;
if region.addr % align != 0 {
return Err(Error::new(
Errno::EINVAL,
format!(
"addr {:#x} violates mem-region-alignment {}",
region.addr, align
),
));
}
if region.len % align != 0 {
return Err(Error::new(
Errno::EINVAL,
format!(
"len {:#x} violates mem-region-alignment {}",
region.len, align
),
));
}
if self.mapped_regions.contains(region) {
return Err(Error::new(Errno::EINVAL, "memory region already mapped"));
}
self.driver.map_mem_region(region)?;
self.mapped_regions.insert(*region);
Ok(())
}
pub fn unmap_mem_region(&mut self, region: &MemoryRegion) {
if self.mapped_regions.remove(region) {
self.driver.unmap_mem_region(region);
}
}
pub fn get_queue(&mut self, index: usize) -> Result<&mut Blkioq> {
self.driver.get_queue(index)
}
pub fn get_poll_queue(&mut self, index: usize) -> Result<&mut Blkioq> {
self.driver.get_poll_queue(index)
}
}
impl Drop for Blkio {
fn drop(&mut self) {
for region in &self.mapped_regions {
self.driver.unmap_mem_region(region);
}
for region in &self.allocated_regions {
self.driver.free_mem_region(region);
}
}
}