use std::fs::File;
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::sync::{Arc, Mutex};
use super::connection::Endpoint;
use super::message::*;
use super::{Error, HandlerResult, Result};
/// Services requests that a vhost-user backend sends to the frontend.
///
/// Every method has a default implementation that fails with `ENOSYS`, so an
/// implementor only overrides the requests it actually supports. On success the
/// returned `u64` is the value [`FrontendReqHandler`] places in the REPLY_ACK
/// reply when one is sent; on failure the OS error code (if any) is reported
/// back as a negative errno.
pub trait VhostUserFrontendReqHandler {
    /// Handle a `CONFIG_CHANGE_MSG` notification (carries no payload).
    fn handle_config_change(&self) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle a `SHARED_OBJECT_ADD` request for the object named in `_msg`.
    fn shared_object_add(&self, _msg: &VhostUserSharedMsg) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle a `SHARED_OBJECT_REMOVE` request for the object named in `_msg`.
    fn shared_object_remove(&self, _msg: &VhostUserSharedMsg) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle a `SHARED_OBJECT_LOOKUP` request; `_file` is the single fd that
    /// accompanied the message.
    fn shared_object_lookup(
        &self,
        _msg: &VhostUserSharedMsg,
        _file: &dyn AsRawFd,
    ) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle a `SHMEM_MAP` request; `_file` is the single fd to be mapped.
    fn shmem_map(&self, _map: &VhostUserMMap, _file: &dyn AsRawFd) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle a `SHMEM_UNMAP` request (no fd attached).
    fn shmem_unmap(&self, _map: &VhostUserMMap) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }
}
/// Mutable-receiver variant of [`VhostUserFrontendReqHandler`].
///
/// Implement this when servicing a request needs `&mut self`; wrapping the
/// implementor in a `Mutex` turns it into a [`VhostUserFrontendReqHandler`].
/// Every method defaults to failing with `ENOSYS`.
pub trait VhostUserFrontendReqHandlerMut {
    /// Handle a `CONFIG_CHANGE_MSG` notification (carries no payload).
    fn handle_config_change(&mut self) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle a `SHARED_OBJECT_ADD` request for the object named in `_msg`.
    fn shared_object_add(&mut self, _msg: &VhostUserSharedMsg) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle a `SHARED_OBJECT_REMOVE` request for the object named in `_msg`.
    fn shared_object_remove(&mut self, _msg: &VhostUserSharedMsg) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle a `SHARED_OBJECT_LOOKUP` request; `_file` is the single fd that
    /// accompanied the message.
    fn shared_object_lookup(
        &mut self,
        _msg: &VhostUserSharedMsg,
        _file: &dyn AsRawFd,
    ) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle a `SHMEM_MAP` request; `_file` is the single fd to be mapped.
    fn shmem_map(&mut self, _map: &VhostUserMMap, _file: &dyn AsRawFd) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle a `SHMEM_UNMAP` request (no fd attached).
    fn shmem_unmap(&mut self, _map: &VhostUserMMap) -> HandlerResult<u64> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }
}
/// Exposes a `Mutex`-wrapped [`VhostUserFrontendReqHandlerMut`] as a shared
/// [`VhostUserFrontendReqHandler`].
///
/// Each request holds the lock for the duration of the forwarded call.
///
/// # Panics
///
/// Every method panics if the mutex has been poisoned.
impl<S: VhostUserFrontendReqHandlerMut> VhostUserFrontendReqHandler for Mutex<S> {
    fn handle_config_change(&self) -> HandlerResult<u64> {
        let mut guard = self.lock().unwrap();
        guard.handle_config_change()
    }

    fn shared_object_add(&self, msg: &VhostUserSharedMsg) -> HandlerResult<u64> {
        let mut guard = self.lock().unwrap();
        guard.shared_object_add(msg)
    }

    fn shared_object_remove(&self, msg: &VhostUserSharedMsg) -> HandlerResult<u64> {
        let mut guard = self.lock().unwrap();
        guard.shared_object_remove(msg)
    }

    fn shared_object_lookup(
        &self,
        msg: &VhostUserSharedMsg,
        file: &dyn AsRawFd,
    ) -> HandlerResult<u64> {
        let mut guard = self.lock().unwrap();
        guard.shared_object_lookup(msg, file)
    }

    fn shmem_map(&self, map: &VhostUserMMap, file: &dyn AsRawFd) -> HandlerResult<u64> {
        let mut guard = self.lock().unwrap();
        guard.shmem_map(map, file)
    }

    fn shmem_unmap(&self, map: &VhostUserMMap) -> HandlerResult<u64> {
        let mut guard = self.lock().unwrap();
        guard.shmem_unmap(map)
    }
}
/// Receives requests from a vhost-user backend over a private `UnixStream`
/// pair and dispatches them to a [`VhostUserFrontendReqHandler`].
pub struct FrontendReqHandler<S: VhostUserFrontendReqHandler> {
    /// Receiving end of the socket pair: requests are read from it and
    /// REPLY_ACK replies are written back on it.
    sub_sock: Endpoint<VhostUserMsgHeader<BackendReq>>,
    /// Other end of the socket pair; its fd is exposed via `get_tx_raw_fd`.
    tx_sock: UnixStream,
    /// Whether requests carrying the need-reply flag get a REPLY_ACK reply.
    reply_ack_negotiated: bool,
    /// Handler that services decoded requests.
    backend: Arc<S>,
    /// Errno recorded by `set_failed`; while `Some`, requests are rejected.
    error: Option<i32>,
}
impl<S: VhostUserFrontendReqHandler> FrontendReqHandler<S> {
    /// Create a handler that dispatches backend requests to `backend`.
    ///
    /// A connected `UnixStream` pair is created. One end is kept for receiving
    /// requests and sending replies. The other end's fd is available through
    /// [`FrontendReqHandler::get_tx_raw_fd`]. Reply acknowledgement starts
    /// disabled and no failure is recorded.
    ///
    /// # Errors
    ///
    /// Returns [`Error::SocketError`] if the socket pair cannot be created.
    pub fn new(backend: Arc<S>) -> Result<Self> {
        let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
        Ok(FrontendReqHandler {
            sub_sock: Endpoint::<VhostUserMsgHeader<BackendReq>>::from_stream(rx),
            tx_sock: tx,
            reply_ack_negotiated: false,
            backend,
            error: None,
        })
    }

    /// Return the raw fd of the sending end of the socket pair.
    ///
    /// The fd stays owned by this handler and is closed when the handler is
    /// dropped. Callers that need to keep it independently should `dup` it.
    pub fn get_tx_raw_fd(&self) -> RawFd {
        self.tx_sock.as_raw_fd()
    }

    /// Enable or disable REPLY_ACK replies.
    ///
    /// When enabled, every request whose header has the need-reply flag set is
    /// answered with a `u64` status: the handler's result on success, or a
    /// negative errno on failure.
    pub fn set_reply_ack_flag(&mut self, enable: bool) {
        self.reply_ack_negotiated = enable;
    }

    /// Record a failure with the given errno, or clear it when `error` is 0.
    ///
    /// While a failure is recorded, [`FrontendReqHandler::handle_request`]
    /// returns [`Error::SocketBroken`] without touching the socket.
    pub fn set_failed(&mut self, error: i32) {
        self.error = if error == 0 { None } else { Some(error) };
    }

    /// Receive one request, dispatch it to the handler and, when negotiated
    /// and requested, send a REPLY_ACK reply.
    ///
    /// Returns the `u64` produced by the handler.
    ///
    /// # Errors
    ///
    /// * [`Error::SocketBroken`] if a failure was recorded with `set_failed`.
    /// * [`Error::InvalidMessage`] for malformed, oversized or unknown requests,
    ///   or when the attached files do not match what the request requires.
    /// * [`Error::ReqHandlerError`] wrapping an error returned by the handler.
    /// * Any error from receiving the request or sending the reply.
    pub fn handle_request(&mut self) -> Result<u64> {
        self.check_state()?;
        let (hdr, files) = self.sub_sock.recv_header()?;
        self.check_attached_files(&hdr, &files)?;
        let (size, buf) = self.recv_body(&hdr)?;

        let res = match hdr.get_code() {
            Ok(BackendReq::CONFIG_CHANGE_MSG) => {
                self.check_msg_size(&hdr, size, 0)?;
                self.backend
                    .handle_config_change()
                    .map_err(Error::ReqHandlerError)
            }
            Ok(BackendReq::SHARED_OBJECT_ADD) => {
                let msg = self.extract_msg_body::<VhostUserSharedMsg>(&hdr, size, &buf)?;
                self.backend
                    .shared_object_add(&msg)
                    .map_err(Error::ReqHandlerError)
            }
            Ok(BackendReq::SHARED_OBJECT_REMOVE) => {
                let msg = self.extract_msg_body::<VhostUserSharedMsg>(&hdr, size, &buf)?;
                self.backend
                    .shared_object_remove(&msg)
                    .map_err(Error::ReqHandlerError)
            }
            Ok(BackendReq::SHARED_OBJECT_LOOKUP) => {
                let msg = self.extract_msg_body::<VhostUserSharedMsg>(&hdr, size, &buf)?;
                let file = Self::single_file(&files)?;
                self.backend
                    .shared_object_lookup(&msg, file)
                    .map_err(Error::ReqHandlerError)
            }
            Ok(BackendReq::SHMEM_MAP) => {
                let msg = self.extract_msg_body::<VhostUserMMap>(&hdr, size, &buf)?;
                let file = Self::single_file(&files)?;
                self.backend
                    .shmem_map(&msg, file)
                    .map_err(Error::ReqHandlerError)
            }
            Ok(BackendReq::SHMEM_UNMAP) => {
                let msg = self.extract_msg_body::<VhostUserMMap>(&hdr, size, &buf)?;
                self.backend
                    .shmem_unmap(&msg)
                    .map_err(Error::ReqHandlerError)
            }
            _ => Err(Error::InvalidMessage),
        };
        // The handler only borrowed the received fds; close them before replying.
        drop(files);

        self.send_ack_message(&hdr, &res)?;
        res
    }

    /// Receive the payload announced by `hdr`.
    ///
    /// Returns the number of bytes received together with the buffer. An empty
    /// payload yields `(0, [])` without reading from the socket.
    fn recv_body(&mut self, hdr: &VhostUserMsgHeader<BackendReq>) -> Result<(usize, Vec<u8>)> {
        let len = hdr.get_size() as usize;
        if len == 0 {
            return Ok((0, Vec::new()));
        }
        if len > MAX_MSG_SIZE {
            return Err(Error::InvalidMessage);
        }
        let (received, buf) = self.sub_sock.recv_data(len)?;
        // A short read means the peer sent less than it announced.
        if received != len {
            return Err(Error::InvalidMessage);
        }
        Ok((received, buf))
    }

    /// Return the first attached file.
    ///
    /// `check_attached_files` has already verified that exactly one file is
    /// present for the requests that call this. The fallible form only avoids
    /// a panic should that invariant ever be broken.
    fn single_file(files: &Option<Vec<File>>) -> Result<&File> {
        files
            .as_deref()
            .and_then(<[File]>::first)
            .ok_or(Error::InvalidMessage)
    }

    /// Fail with [`Error::SocketBroken`] if a failure has been recorded.
    fn check_state(&self) -> Result<()> {
        match self.error {
            Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
            None => Ok(()),
        }
    }

    /// Validate a request header against the expected payload size.
    ///
    /// Both the announced size and the received byte count must equal
    /// `expected`. The header must not be a reply and must carry version 1.
    fn check_msg_size(
        &self,
        hdr: &VhostUserMsgHeader<BackendReq>,
        size: usize,
        expected: usize,
    ) -> Result<()> {
        if hdr.get_size() as usize != expected
            || hdr.is_reply()
            || hdr.get_version() != 0x1
            || size != expected
        {
            return Err(Error::InvalidMessage);
        }
        Ok(())
    }

    /// Check that exactly one file accompanies `SHARED_OBJECT_LOOKUP` and
    /// `SHMEM_MAP`, and that no file accompanies any other request.
    fn check_attached_files(
        &self,
        hdr: &VhostUserMsgHeader<BackendReq>,
        files: &Option<Vec<File>>,
    ) -> Result<()> {
        match hdr.get_code() {
            Ok(BackendReq::SHARED_OBJECT_LOOKUP | BackendReq::SHMEM_MAP) => match files {
                Some(files) if files.len() == 1 => Ok(()),
                _ => Err(Error::InvalidMessage),
            },
            _ if files.is_some() => Err(Error::InvalidMessage),
            _ => Ok(()),
        }
    }

    /// Decode a fixed-size message body of type `T` from `buf` and validate it.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidMessage`] if the header or sizes do not match
    /// `size_of::<T>()`, if `buf` is too short, or if `T::is_valid` rejects the
    /// decoded value.
    fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
        &self,
        hdr: &VhostUserMsgHeader<BackendReq>,
        size: usize,
        buf: &[u8],
    ) -> Result<T> {
        let needed = mem::size_of::<T>();
        self.check_msg_size(hdr, size, needed)?;
        // `size` is only the reported byte count. Check the buffer itself so
        // the raw read below can never run past its end.
        if buf.len() < needed {
            return Err(Error::InvalidMessage);
        }
        // SAFETY: `buf` holds at least `size_of::<T>()` initialized bytes
        // (checked above), and `read_unaligned` has no alignment requirement.
        // NOTE(review): this also assumes every bit pattern is a valid `T`
        // (plain-old-data message structs). Confirm this for each
        // `VhostUserMsgValidator` type passed here.
        let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr().cast::<T>()) };
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }
        Ok(msg)
    }

    /// Build a reply header that echoes the request code.
    ///
    /// The header carries the REPLY flag and a payload size of
    /// `size_of::<T>()`.
    fn new_reply_header<T: Sized>(
        &self,
        req: &VhostUserMsgHeader<BackendReq>,
    ) -> Result<VhostUserMsgHeader<BackendReq>> {
        if mem::size_of::<T>() > MAX_MSG_SIZE {
            return Err(Error::InvalidParam);
        }
        self.check_state()?;
        Ok(VhostUserMsgHeader::new(
            req.get_code()?,
            VhostUserHeaderFlag::REPLY.bits(),
            mem::size_of::<T>() as u32,
        ))
    }

    /// Send a REPLY_ACK reply for `req` if acknowledgements are negotiated
    /// and the request asked for one. Otherwise do nothing.
    ///
    /// On success the reply carries the handler's `u64`. On failure it
    /// carries the negated errno from the handler's I/O error, or `-EINVAL`
    /// when no errno is available, encoded as a two's-complement `u64`.
    fn send_ack_message(
        &mut self,
        req: &VhostUserMsgHeader<BackendReq>,
        res: &Result<u64>,
    ) -> Result<()> {
        if !(self.reply_ack_negotiated && req.is_need_reply()) {
            return Ok(());
        }
        let hdr = self.new_reply_header::<VhostUserU64>(req)?;
        let val = match res {
            Ok(n) => *n,
            Err(e) => {
                let errno = match e {
                    Error::ReqHandlerError(ioerr) => {
                        ioerr.raw_os_error().unwrap_or(libc::EINVAL)
                    }
                    _ => libc::EINVAL,
                };
                // Widen before negating so `i32::MIN` cannot overflow. For every
                // other errno this equals the sign-extended `-errno`.
                i64::from(errno).wrapping_neg() as u64
            }
        };
        let msg = VhostUserU64::new(val);
        self.sub_sock.send_message(&hdr, &msg, None)?;
        Ok(())
    }
}
/// Exposes the fd of the endpoint that
/// [`FrontendReqHandler::handle_request`] reads requests from.
impl<S: VhostUserFrontendReqHandler> AsRawFd for FrontendReqHandler<S> {
    fn as_raw_fd(&self) -> RawFd {
        let endpoint = &self.sub_sock;
        endpoint.as_raw_fd()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::io::ErrorKind;
    use uuid::Uuid;

    #[cfg(feature = "vhost-user-backend")]
    use crate::vhost_user::Backend;
    #[cfg(feature = "vhost-user-backend")]
    use std::os::unix::io::FromRawFd;

    /// In-memory handler that tracks shared objects and shmem mappings.
    ///
    /// Each request returns 0 when it succeeded (state changed, or the object
    /// was found) and 1 otherwise, so the tests can observe the outcome.
    struct MockFrontendReqHandler {
        shared_objects: HashSet<Uuid>,
        shmem_mappings: HashSet<(u64, u64)>,
    }

    impl MockFrontendReqHandler {
        fn new() -> Self {
            Self {
                shared_objects: HashSet::new(),
                shmem_mappings: HashSet::new(),
            }
        }
    }

    impl VhostUserFrontendReqHandlerMut for MockFrontendReqHandler {
        fn shared_object_add(&mut self, uuid: &VhostUserSharedMsg) -> HandlerResult<u64> {
            Ok(!self.shared_objects.insert(uuid.uuid) as u64)
        }

        fn shared_object_remove(&mut self, uuid: &VhostUserSharedMsg) -> HandlerResult<u64> {
            Ok(!self.shared_objects.remove(&uuid.uuid) as u64)
        }

        fn shared_object_lookup(
            &mut self,
            uuid: &VhostUserSharedMsg,
            _fd: &dyn AsRawFd,
        ) -> HandlerResult<u64> {
            if self.shared_objects.contains(&uuid.uuid) {
                return Ok(0);
            }
            Ok(1)
        }

        fn shmem_map(&mut self, req: &VhostUserMMap, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
            assert_eq!(req.shmid, 0);
            if self.shmem_mappings.insert((req.shm_offset, req.len)) {
                return Ok(0);
            };
            Ok(1)
        }

        fn shmem_unmap(&mut self, req: &VhostUserMMap) -> HandlerResult<u64> {
            assert_eq!(req.shmid, 0);
            if self.shmem_mappings.remove(&(req.shm_offset, req.len)) {
                return Ok(0);
            }
            Ok(1)
        }
    }

    #[test]
    fn test_default_frontend_impl() {
        struct Frontend;
        impl VhostUserFrontendReqHandler for Frontend {}
        let f = Frontend;
        // Every default method reports ENOSYS, which std maps to `Unsupported`.
        assert_eq!(
            f.handle_config_change().unwrap_err().kind(),
            ErrorKind::Unsupported
        );
        assert_eq!(
            f.shared_object_add(&Default::default()).unwrap_err().kind(),
            ErrorKind::Unsupported
        );
        assert_eq!(
            f.shared_object_remove(&Default::default())
                .unwrap_err()
                .kind(),
            ErrorKind::Unsupported
        );
        assert_eq!(
            f.shared_object_lookup(&Default::default(), &0)
                .unwrap_err()
                .kind(),
            ErrorKind::Unsupported
        );
        assert_eq!(
            f.shmem_map(&Default::default(), &0).unwrap_err().kind(),
            ErrorKind::Unsupported
        );
        assert_eq!(
            f.shmem_unmap(&Default::default()).unwrap_err().kind(),
            ErrorKind::Unsupported
        );
    }

    #[test]
    fn test_default_frontend_impl_mut() {
        struct FrontendMut;
        impl VhostUserFrontendReqHandlerMut for FrontendMut {}
        let mut f = FrontendMut;
        assert_eq!(
            f.handle_config_change().unwrap_err().kind(),
            ErrorKind::Unsupported
        );
        assert_eq!(
            f.shared_object_add(&Default::default()).unwrap_err().kind(),
            ErrorKind::Unsupported
        );
        assert_eq!(
            f.shared_object_remove(&Default::default())
                .unwrap_err()
                .kind(),
            ErrorKind::Unsupported
        );
        assert_eq!(
            f.shared_object_lookup(&Default::default(), &0)
                .unwrap_err()
                .kind(),
            ErrorKind::Unsupported
        );
        assert_eq!(
            f.shmem_map(&Default::default(), &0).unwrap_err().kind(),
            ErrorKind::Unsupported
        );
        assert_eq!(
            f.shmem_unmap(&Default::default()).unwrap_err().kind(),
            ErrorKind::Unsupported
        );
    }

    #[test]
    fn test_new_frontend_req_handler() {
        let backend = Arc::new(Mutex::new(MockFrontendReqHandler::new()));
        let mut handler = FrontendReqHandler::new(backend).unwrap();
        assert!(handler.get_tx_raw_fd() >= 0);
        assert!(handler.as_raw_fd() >= 0);
        handler.check_state().unwrap();
        assert_eq!(handler.error, None);
        handler.set_failed(libc::EAGAIN);
        assert_eq!(handler.error, Some(libc::EAGAIN));
        handler.check_state().unwrap_err();
        // An errno of 0 clears the recorded failure.
        handler.set_failed(0);
        assert_eq!(handler.error, None);
        handler.check_state().unwrap();
    }

    #[test]
    fn test_check_attached_files() {
        let backend = Arc::new(Mutex::new(MockFrontendReqHandler::new()));
        let handler = FrontendReqHandler::new(backend).unwrap();

        // Requests that take no fd must arrive without one.
        let hdr = VhostUserMsgHeader::new(BackendReq::CONFIG_CHANGE_MSG, 0, 0);
        handler.check_attached_files(&hdr, &None).unwrap();
        let file = File::open("/dev/null").unwrap();
        handler
            .check_attached_files(&hdr, &Some(vec![file]))
            .unwrap_err();

        // SHMEM_MAP requires exactly one fd.
        let hdr = VhostUserMsgHeader::new(BackendReq::SHMEM_MAP, 0, 0);
        handler.check_attached_files(&hdr, &None).unwrap_err();
        let file = File::open("/dev/null").unwrap();
        handler
            .check_attached_files(&hdr, &Some(vec![file]))
            .unwrap();
    }

    #[cfg(feature = "vhost-user-backend")]
    #[test]
    fn test_frontend_backend_req_handler() {
        let backend = Arc::new(Mutex::new(MockFrontendReqHandler::new()));
        let mut handler = FrontendReqHandler::new(backend).unwrap();

        // The backend gets its own fd for the tx end so both sides can own one.
        let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
        if fd < 0 {
            panic!("failed to duplicated tx fd!");
        }
        let stream = unsafe { UnixStream::from_raw_fd(fd) };
        let backend = Backend::from_stream(stream);

        // Expected handler results, in the order the backend sends requests below.
        let frontend_handler = std::thread::spawn(move || {
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 1);
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 1);
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 1);
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 1);
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 0);
        });

        // Without REPLY_ACK no reply is awaited, so every backend call succeeds
        // regardless of what the handler returns.
        backend.set_shared_object_flag(true);
        let shobj_msg = VhostUserSharedMsg {
            uuid: Uuid::new_v4(),
        };
        assert!(backend.shared_object_add(&shobj_msg).is_ok());
        assert!(backend.shared_object_add(&shobj_msg).is_ok());
        assert!(backend.shared_object_lookup(&shobj_msg, &fd).is_ok());
        assert!(backend
            .shared_object_lookup(
                &VhostUserSharedMsg {
                    uuid: Uuid::new_v4(),
                },
                &fd,
            )
            .is_ok());
        assert!(backend.shared_object_remove(&shobj_msg).is_ok());
        assert!(backend.shared_object_remove(&shobj_msg).is_ok());

        backend.set_shmem_flag(true);
        let (_, some_fd_to_map) = UnixStream::pair().unwrap();
        let map_request1 = VhostUserMMap {
            shm_offset: 0,
            len: 4096,
            ..Default::default()
        };
        let map_request2 = VhostUserMMap {
            shm_offset: 4096,
            len: 8192,
            ..Default::default()
        };
        backend.shmem_map(&map_request1, &some_fd_to_map).unwrap();
        backend.shmem_unmap(&map_request2).unwrap();
        backend.shmem_map(&map_request2, &some_fd_to_map).unwrap();
        backend.shmem_unmap(&map_request2).unwrap();
        backend.shmem_unmap(&map_request1).unwrap();

        assert!(frontend_handler.join().is_ok());
    }

    #[cfg(feature = "vhost-user-backend")]
    #[test]
    fn test_frontend_backend_req_handler_with_ack() {
        let backend = Arc::new(Mutex::new(MockFrontendReqHandler::new()));
        let mut handler = FrontendReqHandler::new(backend).unwrap();
        handler.set_reply_ack_flag(true);

        let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
        if fd < 0 {
            panic!("failed to duplicated tx fd!");
        }
        let stream = unsafe { UnixStream::from_raw_fd(fd) };
        let backend = Backend::from_stream(stream);

        let frontend_handler = std::thread::spawn(move || {
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 1);
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 1);
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 1);
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 1);
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 0);
            assert_eq!(handler.handle_request().unwrap(), 0);
        });

        // With REPLY_ACK the handler's result is sent back, and a non-zero
        // value surfaces as an error on the backend side.
        backend.set_reply_ack_flag(true);
        backend.set_shared_object_flag(true);
        let shobj_msg = VhostUserSharedMsg {
            uuid: Uuid::new_v4(),
        };
        assert!(backend.shared_object_add(&shobj_msg).is_ok());
        assert!(backend.shared_object_add(&shobj_msg).is_err());
        assert!(backend.shared_object_lookup(&shobj_msg, &fd).is_ok());
        assert!(backend
            .shared_object_lookup(
                &VhostUserSharedMsg {
                    uuid: Uuid::new_v4(),
                },
                &fd,
            )
            .is_err());
        assert!(backend.shared_object_remove(&shobj_msg).is_ok());
        assert!(backend.shared_object_remove(&shobj_msg).is_err());

        backend.set_shmem_flag(true);
        let (_, some_fd_to_map) = UnixStream::pair().unwrap();
        let map_request1 = VhostUserMMap {
            shm_offset: 0,
            len: 4096,
            ..Default::default()
        };
        let map_request2 = VhostUserMMap {
            shm_offset: 4096,
            len: 8192,
            ..Default::default()
        };
        backend.shmem_map(&map_request1, &some_fd_to_map).unwrap();
        backend.shmem_unmap(&map_request2).unwrap_err();
        backend.shmem_map(&map_request2, &some_fd_to_map).unwrap();
        backend.shmem_unmap(&map_request2).unwrap();
        backend.shmem_unmap(&map_request1).unwrap();

        assert!(frontend_handler.join().is_ok());
    }
}