#[macro_use]
extern crate log;
use std::fmt::{Display, Formatter};
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::thread;
use vhost::vhost_user::{BackendListener, BackendReqHandler, Error as VhostUserError, Listener};
use vm_memory::mmap::NewBitmap;
use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap};
use self::handler::VhostUserHandler;
mod backend;
pub use self::backend::{VhostUserBackend, VhostUserBackendMut};
mod event_loop;
pub use self::event_loop::VringEpollHandler;
mod handler;
pub use self::handler::VhostUserHandlerError;
pub mod bitmap;
use crate::bitmap::BitmapReplace;
mod vring;
pub use self::vring::{
VringMutex, VringRwLock, VringState, VringStateGuard, VringStateMutGuard, VringT,
};
// The `postcopy` and `xen` features are mutually exclusive: refuse to build with both enabled,
// except when building documentation (`doc`) or when the
// `RUSTDOC_disable_feature_compat_errors` cfg is set.
#[cfg(all(
    not(RUSTDOC_disable_feature_compat_errors),
    not(doc),
    feature = "postcopy",
    feature = "xen"
))]
compile_error!("Both `postcopy` and `xen` features can not be enabled at the same time.");
/// Shorthand for `GuestMemoryAtomic<GuestMemoryMmap<B>>`, generic over the dirty-bitmap type `B`.
// NOTE(review): not referenced in this file; presumably used by sibling modules via `crate::GM`.
type GM<B> = GuestMemoryAtomic<GuestMemoryMmap<B>>;
/// Errors related to the vhost-user daemon.
#[derive(Debug)]
pub enum Error {
    /// Failed to create a new vhost-user handler.
    NewVhostUserHandler(VhostUserHandlerError),
    /// Failed to create the vhost-user backend listener, or to accept a connection on it.
    CreateBackendListener(VhostUserError),
    /// Failed to create the vhost-user backend request handler (client mode).
    CreateBackendReqHandler(VhostUserError),
    /// Failed to create the vhost-user listener socket.
    CreateVhostUserListener(VhostUserError),
    /// Failed to spawn the daemon thread.
    StartDaemon(std::io::Error),
    /// The daemon thread panicked; carries the panic payload returned by `JoinHandle::join`.
    WaitDaemon(std::boxed::Box<dyn std::any::Any + std::marker::Send>),
    /// Failed to handle a vhost-user request.
    HandleRequest(VhostUserError),
}

impl Display for Error {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        match self {
            Error::NewVhostUserHandler(e) => write!(f, "cannot create vhost user handler: {e}"),
            Error::CreateBackendListener(e) => write!(f, "cannot create backend listener: {e}"),
            Error::CreateBackendReqHandler(e) => {
                write!(f, "cannot create backend req handler: {e}")
            }
            Error::CreateVhostUserListener(e) => {
                write!(f, "cannot create vhost-user listener: {e}")
            }
            Error::StartDaemon(e) => write!(f, "failed to start daemon: {e}"),
            Error::WaitDaemon(_e) => write!(f, "failed to wait for daemon exit"),
            Error::HandleRequest(e) => write!(f, "failed to handle request: {e}"),
        }
    }
}

// Lets the daemon error be used with `?` into `Box<dyn std::error::Error>` and with other
// error-trait based APIs. `source()` is left at its default (`None`) on purpose: the `Display`
// output above already embeds the inner error's message, so chaining it would duplicate text.
impl std::error::Error for Error {}
/// Result type returned by vhost-user daemon operations, using this crate's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// A vhost-user daemon that serves requests for a [`VhostUserBackend`] on a dedicated thread.
pub struct VhostUserDaemon<T: VhostUserBackend> {
    // Name given to the main daemon thread when it is spawned.
    name: String,
    // Request handler, shared (via `Arc`) with the backend listener / request handler.
    handler: Arc<Mutex<VhostUserHandler<T>>>,
    // Join handle of the main daemon thread; `None` before a start call and after `wait` took it.
    main_thread: Option<thread::JoinHandle<Result<()>>>,
}
impl<T> VhostUserDaemon<T>
where
T: VhostUserBackend + Clone + 'static,
T::Bitmap: BitmapReplace + NewBitmap + Clone + Send + Sync,
T::Vring: Clone + Send + Sync,
{
pub fn new(
name: String,
backend: T,
atomic_mem: GuestMemoryAtomic<GuestMemoryMmap<T::Bitmap>>,
) -> Result<Self> {
let handler = Arc::new(Mutex::new(
VhostUserHandler::new(backend, atomic_mem).map_err(Error::NewVhostUserHandler)?,
));
Ok(VhostUserDaemon {
name,
handler,
main_thread: None,
})
}
fn start_daemon(
&mut self,
mut handler: BackendReqHandler<Mutex<VhostUserHandler<T>>>,
) -> Result<()> {
let handle = thread::Builder::new()
.name(self.name.clone())
.spawn(move || loop {
handler.handle_request().map_err(Error::HandleRequest)?;
})
.map_err(Error::StartDaemon)?;
self.main_thread = Some(handle);
Ok(())
}
pub fn start_client(&mut self, socket_path: &str) -> Result<()> {
let backend_handler = BackendReqHandler::connect(socket_path, self.handler.clone())
.map_err(Error::CreateBackendReqHandler)?;
self.start_daemon(backend_handler)
}
pub fn start(&mut self, listener: &mut Listener) -> Result<()> {
let mut backend_listener = BackendListener::new(listener, self.handler.clone())
.map_err(Error::CreateBackendListener)?;
let backend_handler = self.accept(&mut backend_listener)?;
self.start_daemon(backend_handler)
}
fn accept(
&self,
backend_listener: &mut BackendListener<Mutex<VhostUserHandler<T>>>,
) -> Result<BackendReqHandler<Mutex<VhostUserHandler<T>>>> {
loop {
match backend_listener.accept() {
Err(e) => return Err(Error::CreateBackendListener(e)),
Ok(Some(v)) => return Ok(v),
Ok(None) => continue,
}
}
}
pub fn wait(&mut self) -> Result<()> {
if let Some(handle) = self.main_thread.take() {
match handle.join().map_err(Error::WaitDaemon)? {
Ok(()) => Ok(()),
Err(Error::HandleRequest(VhostUserError::SocketBroken(_))) => Ok(()),
Err(e) => Err(e),
}
} else {
Ok(())
}
}
pub fn serve<P: AsRef<Path>>(&mut self, socket: P) -> Result<()> {
let mut listener = Listener::new(socket, true).map_err(Error::CreateVhostUserListener)?;
self.start(&mut listener)?;
let result = self.wait();
self.handler.lock().unwrap().send_exit_event();
match &result {
Err(e) => match e {
Error::HandleRequest(VhostUserError::Disconnected) => Ok(()),
Error::HandleRequest(VhostUserError::PartialMessage) => Ok(()),
_ => result,
},
_ => result,
}
}
pub fn get_epoll_handlers(&self) -> Vec<Arc<VringEpollHandler<T>>> {
self.handler.lock().unwrap().get_epoll_handlers()
}
}
#[cfg(test)]
mod tests {
    use super::backend::tests::MockVhostBackend;
    use super::*;
    use libc::EAGAIN;
    use std::os::unix::net::{UnixListener, UnixStream};
    use std::sync::Barrier;
    use std::time::Duration;
    use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap};

    // Server mode: the daemon listens on `path`; a helper thread connects and then hangs up.
    #[test]
    fn test_new_daemon() {
        let mem = GuestMemoryAtomic::new(
            GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
        );
        let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
        let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap();
        // NOTE(review): two handlers is presumably what MockVhostBackend's thread layout yields.
        let handlers = daemon.get_epoll_handlers();
        assert_eq!(handlers.len(), 2);
        let barrier = Arc::new(Barrier::new(2));
        let tmpdir = tempfile::tempdir().unwrap();
        let path = tmpdir.path().join("socket");
        thread::scope(|s| {
            s.spawn(|| {
                // First rendezvous: the listener socket exists, so connecting can succeed.
                barrier.wait();
                let socket = UnixStream::connect(&path).unwrap();
                // Second rendezvous: keep the connection open until the daemon has started.
                barrier.wait();
                drop(socket)
            });
            let mut listener = Listener::new(&path, false).unwrap();
            barrier.wait();
            daemon.start(&mut listener).unwrap();
            barrier.wait();
            // The peer hung up, so the daemon thread ends with an error that `wait` returns.
            daemon.wait().unwrap_err();
            // The first `wait` took the join handle, so a second call has nothing to join.
            daemon.wait().unwrap();
        });
    }

    // Client mode: the helper thread owns the listening socket and the daemon connects to it.
    #[test]
    fn test_new_daemon_client() {
        let mem = GuestMemoryAtomic::new(
            GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
        );
        let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
        let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap();
        let handlers = daemon.get_epoll_handlers();
        assert_eq!(handlers.len(), 2);
        let barrier = Arc::new(Barrier::new(2));
        let tmpdir = tempfile::tempdir().unwrap();
        let path = tmpdir.path().join("socket");
        thread::scope(|s| {
            s.spawn(|| {
                let listener = UnixListener::bind(&path).unwrap();
                // First rendezvous: the socket is bound before the daemon tries to connect.
                barrier.wait();
                let (stream, _) = listener.accept().unwrap();
                // Second rendezvous: hold the connection until the daemon has started.
                barrier.wait();
                drop(stream)
            });
            barrier.wait();
            daemon
                .start_client(path.as_path().to_str().unwrap())
                .unwrap();
            barrier.wait();
            // Same expectations as in server mode: first `wait` errors, second is a no-op.
            daemon.wait().unwrap_err();
            daemon.wait().unwrap();
        });
    }

    // `serve` end to end: exit events must not be raised while the daemon is serving, and must
    // be raised once the frontend disconnects and `serve` returns.
    #[test]
    fn test_daemon_serve() {
        let mem = GuestMemoryAtomic::new(
            GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
        );
        let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
        let mut daemon = VhostUserDaemon::new("test".to_owned(), backend.clone(), mem).unwrap();
        let tmpdir = tempfile::tempdir().unwrap();
        let socket_path = tmpdir.path().join("socket");
        thread::scope(|s| {
            s.spawn(|| {
                let _ = daemon.serve(&socket_path);
            });
            // NOTE(review): this spin has no upper bound; if `serve` fails before creating the
            // socket the test hangs instead of failing.
            while !socket_path.exists() {
                thread::sleep(Duration::from_millis(10));
            }
            // Consuming an unsignalled exit event is expected to fail with EAGAIN (presumably
            // the event fd is non-blocking).
            for thread_id in 0..backend.queues_per_thread().len() {
                let fd = backend.exit_event(thread_id).unwrap();
                assert_eq!(
                    fd.0.consume().unwrap_err().raw_os_error().unwrap(),
                    EAGAIN,
                    "exit event should not have been raised yet!"
                );
            }
            // Connecting and immediately hanging up makes `serve` return.
            let socket = UnixStream::connect(&socket_path).unwrap();
            drop(socket);
        });
        // `thread::scope` has joined the serving thread, so `serve` has sent the exit events.
        let backend = backend.lock().unwrap();
        for thread_id in 0..backend.queues_per_thread().len() {
            let fd = backend.exit_event(thread_id).unwrap();
            assert!(fd.0.consume().is_ok(), "No exit event was raised!");
        }
    }
}