use std::io::Error as IOError;
pub mod message;
pub use self::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
mod connection;
pub use self::connection::Listener;
#[cfg(feature = "vhost-user-frontend")]
mod frontend;
#[cfg(feature = "vhost-user-frontend")]
pub use self::frontend::{Frontend, VhostUserFrontend};
#[cfg(feature = "vhost-user")]
mod frontend_req_handler;
#[cfg(feature = "vhost-user")]
pub use self::frontend_req_handler::{
FrontendReqHandler, VhostUserFrontendReqHandler, VhostUserFrontendReqHandlerMut,
};
#[cfg(feature = "vhost-user-backend")]
mod backend;
#[cfg(feature = "vhost-user-backend")]
pub use self::backend::BackendListener;
#[cfg(feature = "vhost-user-backend")]
mod backend_req_handler;
#[cfg(feature = "vhost-user-backend")]
pub use self::backend_req_handler::{
BackendReqHandler, VhostUserBackendReqHandler, VhostUserBackendReqHandlerMut,
};
#[cfg(feature = "vhost-user-backend")]
mod backend_req;
#[cfg(feature = "vhost-user-backend")]
pub use self::backend_req::Backend;
mod gpu_backend_req;
pub mod gpu_message;
pub use self::gpu_backend_req::GpuBackend;
/// Error type for the vhost-user module.
#[derive(Debug)]
pub enum Error {
    /// Invalid parameters.
    InvalidParam,
    /// Invalid operation; the payload is a static description of the reason.
    InvalidOperation(&'static str),
    /// An inactive virtio feature was involved; the payload holds its flag bits.
    InactiveFeature(VhostUserVirtioFeatures),
    /// An inactive protocol feature was involved; the payload holds its flag bits.
    InactiveOperation(VhostUserProtocolFeatures),
    /// Invalid message.
    InvalidMessage,
    /// Only part of a message was transferred.
    PartialMessage,
    /// The peer disconnected.
    Disconnected,
    /// Message exceeds the allowed size.
    OversizedMsg,
    /// Wrong number of file descriptors attached to a message.
    IncorrectFds,
    /// Could not connect to the peer (the errno conversion maps `EACCES` here).
    SocketConnect(IOError),
    /// Generic socket error (fallback of the errno conversion).
    SocketError(IOError),
    /// The socket is broken (the errno conversion maps `ECONNRESET` and `EPIPE` here).
    SocketBroken(IOError),
    /// Temporary socket error (the errno conversion maps `EAGAIN`, `EWOULDBLOCK`,
    /// `EINTR`, `ENOBUFS` and `ENOMEM` here).
    SocketRetry(IOError),
    /// Internal error on the backend side.
    BackendInternalError,
    /// Internal error on the frontend side.
    FrontendInternalError,
    /// Virtio and protocol features do not match.
    FeatureMismatch,
    /// A request handler failed to handle a request.
    ReqHandlerError(IOError),
    /// Allocating a memfd during `get_inflight_fd` failed.
    MemFdCreateError,
    /// Truncating the memfd during `get_inflight_fd` failed.
    FileTruncateError,
    /// Applying seals to the memfd during `get_inflight_fd` failed.
    MemFdSealError,
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Error::InvalidParam => write!(f, "invalid parameters"),
Error::InvalidOperation(reason) => write!(f, "invalid operation: {reason}"),
Error::InactiveFeature(bits) => write!(f, "inactive feature: {}", bits.bits()),
Error::InactiveOperation(bits) => {
write!(f, "inactive protocol operation: {}", bits.bits())
}
Error::InvalidMessage => write!(f, "invalid message"),
Error::PartialMessage => write!(f, "partial message"),
Error::Disconnected => write!(f, "peer disconnected"),
Error::OversizedMsg => write!(f, "oversized message"),
Error::IncorrectFds => write!(f, "wrong number of attached fds"),
Error::SocketError(e) => write!(f, "socket error: {e}"),
Error::SocketConnect(e) => write!(f, "can't connect to peer: {e}"),
Error::SocketBroken(e) => write!(f, "socket is broken: {e}"),
Error::SocketRetry(e) => write!(f, "temporary socket error: {e}"),
Error::BackendInternalError => write!(f, "backend internal error"),
Error::FrontendInternalError => write!(f, "Frontend internal error"),
Error::FeatureMismatch => write!(f, "virtio/protocol features mismatch"),
Error::ReqHandlerError(e) => write!(f, "handler failed to handle request: {e}"),
Error::MemFdCreateError => {
write!(f, "handler failed to allocate memfd during get_inflight_fd")
}
Error::FileTruncateError => {
write!(f, "handler failed to truncate memfd during get_inflight_fd")
}
Error::MemFdSealError => write!(
f,
"handler failed to apply seals to memfd during get_inflight_fd"
),
}
}
}
impl std::error::Error for Error {}
impl Error {
pub fn should_reconnect(&self) -> bool {
match *self {
Error::PartialMessage => true,
Error::SocketBroken(_) => true,
Error::BackendInternalError => true,
Error::FrontendInternalError => true,
Error::SocketRetry(_) => false,
Error::Disconnected => false,
Error::InvalidParam | Error::InvalidOperation(_) => false,
Error::InactiveFeature(_) | Error::InactiveOperation(_) => false,
Error::InvalidMessage | Error::IncorrectFds | Error::OversizedMsg => false,
Error::SocketError(_) | Error::SocketConnect(_) => false,
Error::FeatureMismatch => false,
Error::ReqHandlerError(_) => false,
Error::MemFdCreateError | Error::FileTruncateError | Error::MemFdSealError => false,
}
}
}
impl std::convert::From<vmm_sys_util::errno::Error> for Error {
#[allow(unreachable_patterns)] #[allow(clippy::match_overlapping_arm)] fn from(err: vmm_sys_util::errno::Error) -> Self {
match err.errno() {
libc::EAGAIN => Error::SocketRetry(IOError::from_raw_os_error(libc::EAGAIN)),
libc::EWOULDBLOCK => Error::SocketRetry(IOError::from_raw_os_error(libc::EWOULDBLOCK)),
libc::EINTR => Error::SocketRetry(IOError::from_raw_os_error(libc::EINTR)),
libc::ENOBUFS => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOBUFS)),
libc::ENOMEM => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOMEM)),
libc::ECONNRESET => Error::SocketBroken(IOError::from_raw_os_error(libc::ECONNRESET)),
libc::EPIPE => Error::SocketBroken(IOError::from_raw_os_error(libc::EPIPE)),
libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)),
e => Error::SocketError(IOError::from_raw_os_error(e)),
}
}
}
pub type Result<T> = std::result::Result<T, Error>;
pub type HandlerResult<T> = std::result::Result<T, IOError>;
/// Extracts the sole file from an optional list of received files.
///
/// Returns `Some(file)` only when `files` is `Some` and contains exactly one
/// entry. Returns `None` when `files` is `None`, empty, or holds more than one
/// file; in the last case all of the files are dropped, which closes them.
#[cfg(any(feature = "vhost-user-backend", feature = "vhost-user-frontend"))]
pub(crate) fn take_single_file(files: Option<Vec<std::fs::File>>) -> Option<std::fs::File> {
    match files {
        Some(mut received) if received.len() == 1 => received.pop(),
        _ => None,
    }
}
/// Declares a fieldless enum with `#[repr($T)]` and generates conversions
/// between the enum and its underlying integer type `$T`.
///
/// * `TryFrom<$T>` returns the variant whose discriminant equals the given
///   value (checked in declaration order), or `Err(())` if none matches.
/// * `From<$enum> for $T` returns the variant's discriminant.
///
/// Attributes written on the enum and on its variants are forwarded as-is.
macro_rules! enum_value {
    (
        $(#[$meta:meta])*
        $vis:vis enum $enum:ident: $T:tt {
            $(
                $(#[$variant_meta:meta])*
                $variant:ident $(= $val:expr)?,
            )*
        }
    ) => {
        #[repr($T)]
        $(#[$meta])*
        $vis enum $enum {
            $($(#[$variant_meta])* $variant $(= $val)?,)*
        }

        impl std::convert::TryFrom<$T> for $enum {
            type Error = ();

            fn try_from(raw: $T) -> std::result::Result<Self, Self::Error> {
                // One comparison per declared variant; the first hit wins.
                $(
                    if raw == $enum::$variant as $T {
                        return Ok($enum::$variant);
                    }
                )*
                Err(())
            }
        }

        impl std::convert::From<$enum> for $T {
            fn from(value: $enum) -> $T {
                value as $T
            }
        }
    }
}
use enum_value;
#[cfg(all(test, feature = "vhost-user-backend"))]
mod dummy_backend;
#[cfg(all(test, feature = "vhost-user-frontend", feature = "vhost-user-backend"))]
mod tests {
    use message::VhostUserSharedMsg;
    use std::fs::File;
    use std::os::unix::io::AsRawFd;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};
    use std::thread;
    use uuid::Uuid;
    use vmm_sys_util::rand::rand_alphanumerics;
    use vmm_sys_util::tempfile::TempFile;

    use super::dummy_backend::{DummyBackendReqHandler, VIRTIO_FEATURES};
    use super::message::*;
    use super::*;
    use crate::backend::VhostBackend;
    use crate::{VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};

    /// Builds a random path of the form `/tmp/vhost_test_XXXXXXXX` for a test listener.
    fn temp_path() -> PathBuf {
        PathBuf::from(format!(
            "/tmp/vhost_test_{}",
            rand_alphanumerics(8).to_str().unwrap()
        ))
    }

    /// Creates a `Listener` at `path`, connects a `Frontend` to it and returns the
    /// frontend together with the backend request handler accepted on the
    /// listener, which dispatches incoming requests to `backend`.
    fn create_backend<P, S>(path: P, backend: Arc<S>) -> (Frontend, BackendReqHandler<S>)
    where
        P: AsRef<Path>,
        S: VhostUserBackendReqHandler,
    {
        let mut listener = Listener::new(&path, true).unwrap();
        let mut backend_listener = BackendListener::new(&mut listener, backend).unwrap();
        let frontend = Frontend::connect(&path, 1).unwrap();
        (frontend, backend_listener.accept().unwrap().unwrap())
    }

    #[test]
    fn create_dummy_backend() {
        let backend = Arc::new(Mutex::new(DummyBackendReqHandler::new()));

        backend.set_owner().unwrap();
        // A second set_owner() on the same dummy backend is rejected.
        assert!(backend.set_owner().is_err());
    }

    #[test]
    fn test_set_owner() {
        let backend_be = Arc::new(Mutex::new(DummyBackendReqHandler::new()));
        let path = temp_path();
        let (frontend, mut backend) = create_backend(path, backend_be.clone());

        assert!(!backend_be.lock().unwrap().owned);
        frontend.set_owner().unwrap();
        backend.handle_request().unwrap();
        assert!(backend_be.lock().unwrap().owned);

        // The second request is sent without error by the frontend but fails
        // when the backend handles it; ownership is unchanged.
        frontend.set_owner().unwrap();
        assert!(backend.handle_request().is_err());
        assert!(backend_be.lock().unwrap().owned);
    }

    #[test]
    fn test_set_features() {
        let path = temp_path();
        let backend_be = Arc::new(Mutex::new(DummyBackendReqHandler::new()));
        let (mut frontend, mut backend) = create_backend(path, backend_be.clone());

        // The backend serves requests in the order the frontend issues them below.
        let backend_thread = thread::spawn(move || {
            // set_owner()
            backend.handle_request().unwrap();
            assert!(backend_be.lock().unwrap().owned);

            // get_features(), set_features()
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
            assert_eq!(
                backend_be.lock().unwrap().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get_protocol_features(), set_protocol_features()
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
            assert_eq!(
                backend_be.lock().unwrap().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );
        });

        frontend.set_owner().unwrap();

        let features = frontend.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        // Acknowledge every offered feature except bit 0.
        frontend.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        let features = frontend.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        frontend.set_protocol_features(features).unwrap();

        // Join the backend thread so that a failed assertion inside it is
        // re-raised here and fails this test instead of leaving it blocked.
        backend_thread.join().expect("backend thread panicked");
    }

    #[test]
    fn test_frontend_backend_process() {
        let path = temp_path();
        let backend_be = Arc::new(Mutex::new(DummyBackendReqHandler::new()));
        let (mut frontend, mut backend) = create_backend(path, backend_be.clone());

        // The backend serves requests in the order the frontend issues them below.
        let backend_thread = thread::spawn(move || {
            // set_owner()
            backend.handle_request().unwrap();
            assert!(backend_be.lock().unwrap().owned);

            // get_features(), set_features()
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
            assert_eq!(
                backend_be.lock().unwrap().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get_protocol_features(), set_protocol_features()
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
            let mut features = VhostUserProtocolFeatures::all();
            if !cfg!(feature = "xen") {
                features.remove(VhostUserProtocolFeatures::XEN_MMAP);
            }
            assert_eq!(
                backend_be.lock().unwrap().acked_protocol_features,
                features.bits()
            );

            // get_inflight_fd(), set_inflight_fd()
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
            // get_shared_object()
            backend.handle_request().unwrap();
            // get_queue_num()
            backend.handle_request().unwrap();
            // set_mem_table()
            backend.handle_request().unwrap();
            // set_config(), get_config()
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
            // set_backend_request_fd()
            backend.handle_request().unwrap();
            // set_vring_enable()
            backend.handle_request().unwrap();
            // set_log_base()
            backend.handle_request().unwrap();
            // set_log_fd(): handling fails on the backend side, while the
            // frontend call itself returns Ok.
            backend.handle_request().unwrap_err();
            // set_vring_num(), set_vring_base(), set_vring_addr()
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
            // set_vring_call(), set_vring_kick(), set_vring_err()
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
            // get_max_mem_slots()
            backend.handle_request().unwrap();
            // add_mem_region(), remove_mem_region()
            backend.handle_request().unwrap();
            backend.handle_request().unwrap();
        });

        frontend.set_owner().unwrap();

        let features = frontend.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        frontend.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        let mut features = frontend.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        if !cfg!(feature = "xen") {
            features.remove(VhostUserProtocolFeatures::XEN_MMAP);
        }
        frontend.set_protocol_features(features).unwrap();

        let (inflight_info, inflight_file) = frontend
            .get_inflight_fd(&VhostUserInflight {
                num_queues: 2,
                queue_size: 256,
                ..Default::default()
            })
            .unwrap();
        frontend
            .set_inflight_fd(&inflight_info, inflight_file.as_raw_fd())
            .unwrap();

        frontend
            .get_shared_object(&VhostUserSharedMsg {
                uuid: Uuid::new_v4(),
            })
            .unwrap();

        let num = frontend.get_queue_num().unwrap();
        assert_eq!(num, 2);

        let eventfd = vmm_sys_util::eventfd::EventFd::new(0).unwrap();
        let mem = [VhostUserMemoryRegionInfo::new(
            0,
            0x10_0000,
            0,
            0,
            eventfd.as_raw_fd(),
        )];
        frontend.set_mem_table(&mem).unwrap();

        frontend
            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8; 4])
            .unwrap();
        let buf = [0x0u8; 4];
        let (reply_body, reply_payload) = frontend
            .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
            .unwrap();
        // Copy the field out before comparing it.
        let offset = reply_body.offset;
        assert_eq!(offset, 0x100);
        assert_eq!(&reply_payload, &[0xa5; 4]);

        frontend.set_backend_request_fd(&eventfd).unwrap();
        frontend.set_vring_enable(0, true).unwrap();

        frontend
            .set_log_base(
                0,
                Some(VhostUserDirtyLogRegion {
                    mmap_size: 0x1000,
                    mmap_offset: 0,
                    mmap_handle: eventfd.as_raw_fd(),
                }),
            )
            .unwrap();
        frontend.set_log_fd(eventfd.as_raw_fd()).unwrap();

        frontend.set_vring_num(0, 256).unwrap();
        frontend.set_vring_base(0, 0).unwrap();
        let config = VringConfigData {
            queue_max_size: 256,
            queue_size: 128,
            flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
            desc_table_addr: 0x1000,
            used_ring_addr: 0x2000,
            avail_ring_addr: 0x3000,
            log_addr: Some(0x4000),
        };
        frontend.set_vring_addr(0, &config).unwrap();
        frontend.set_vring_call(0, &eventfd).unwrap();
        frontend.set_vring_kick(0, &eventfd).unwrap();
        frontend.set_vring_err(0, &eventfd).unwrap();

        let max_mem_slots = frontend.get_max_mem_slots().unwrap();
        assert_eq!(max_mem_slots, 509);

        let region_file: File = TempFile::new().unwrap().into_file();
        let region =
            VhostUserMemoryRegionInfo::new(0x10_0000, 0x10_0000, 0, 0, region_file.as_raw_fd());
        frontend.add_mem_region(&region).unwrap();
        frontend.remove_mem_region(&region).unwrap();

        // Join the backend thread so that a failed assertion inside it is
        // re-raised here and fails this test instead of leaving it blocked.
        backend_thread.join().expect("backend thread panicked");
    }

    #[test]
    fn test_error_display() {
        assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters");
        assert_eq!(
            format!("{}", Error::InvalidOperation("reason")),
            "invalid operation: reason"
        );
    }

    #[test]
    fn test_should_reconnect() {
        assert!(Error::PartialMessage.should_reconnect());
        assert!(Error::BackendInternalError.should_reconnect());
        assert!(Error::FrontendInternalError.should_reconnect());
        assert!(
            Error::SocketBroken(std::io::Error::from_raw_os_error(libc::EPIPE)).should_reconnect()
        );
        assert!(!Error::InvalidParam.should_reconnect());
        assert!(!Error::InvalidOperation("reason").should_reconnect());
        assert!(
            !Error::InactiveFeature(VhostUserVirtioFeatures::PROTOCOL_FEATURES).should_reconnect()
        );
        assert!(!Error::InactiveOperation(VhostUserProtocolFeatures::all()).should_reconnect());
        assert!(!Error::InvalidMessage.should_reconnect());
        assert!(!Error::IncorrectFds.should_reconnect());
        assert!(!Error::OversizedMsg.should_reconnect());
        assert!(!Error::FeatureMismatch.should_reconnect());
        assert!(!Error::Disconnected.should_reconnect());
        assert!(
            !Error::SocketRetry(std::io::Error::from_raw_os_error(libc::EAGAIN)).should_reconnect()
        );
    }

    #[test]
    fn test_error_from_sys_util_error() {
        let e: Error = vmm_sys_util::errno::Error::new(libc::EAGAIN).into();
        if let Error::SocketRetry(e1) = e {
            assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN);
        } else {
            panic!("invalid error code conversion!");
        }

        let e: Error = vmm_sys_util::errno::Error::new(libc::EPIPE).into();
        assert!(matches!(
            e,
            Error::SocketBroken(ref io) if io.raw_os_error() == Some(libc::EPIPE)
        ));

        let e: Error = vmm_sys_util::errno::Error::new(libc::EACCES).into();
        assert!(matches!(e, Error::SocketConnect(_)));

        let e: Error = vmm_sys_util::errno::Error::new(libc::EBADF).into();
        assert!(matches!(
            e,
            Error::SocketError(ref io) if io.raw_os_error() == Some(libc::EBADF)
        ));
    }
}