use std::fs::File;
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::path::Path;
use std::sync::{Arc, Mutex, MutexGuard};
use vm_memory::ByteValued;
use vmm_sys_util::eventfd::EventFd;
use super::connection::Endpoint;
use super::message::*;
use super::{take_single_file, Error as VhostUserError, Result as VhostUserResult};
use crate::backend::{
VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData,
};
use crate::{Error, Result};
/// Trait for a vhost-user master (frontend) to drive the vhost-user
/// protocol extensions beyond the base [`VhostBackend`] operations.
pub trait VhostUserMaster: VhostBackend {
    /// Get the protocol feature bitmask from the underlying vhost implementation.
    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;

    /// Enable protocol features in the underlying vhost implementation.
    fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()>;

    /// Query how many queues the backend supports (requires the MQ protocol feature).
    fn get_queue_num(&mut self) -> Result<u64>;

    /// Signal the backend that polling on the ring is (or is not) wanted.
    fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()>;

    /// Fetch the contents of the virtio device configuration space.
    ///
    /// Returns the reply descriptor together with the configuration payload.
    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        flags: VhostUserConfigFlags,
        buf: &[u8],
    ) -> Result<(VhostUserConfig, VhostUserConfigPayload)>;

    /// Change the contents of the virtio device configuration space.
    fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>;

    /// Hand the slave a socket fd it can use to send requests back to the master.
    fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()>;

    /// Query the in-flight I/O tracking state; the returned file backs the
    /// shared in-flight region (requires INFLIGHT_SHMFD).
    fn get_inflight_fd(
        &mut self,
        inflight: &VhostUserInflight,
    ) -> Result<(VhostUserInflight, File)>;

    /// Install the in-flight I/O tracking state previously obtained from the slave.
    fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()>;

    /// Query the maximum number of memory slots (requires CONFIGURE_MEM_SLOTS).
    fn get_max_mem_slots(&mut self) -> Result<u64>;

    /// Register a single guest memory region with the slave.
    fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>;

    /// Unregister a single guest memory region from the slave.
    fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>;
}
fn error_code<T>(err: VhostUserError) -> Result<T> {
Err(Error::VhostUserProtocol(err))
}
/// Client proxy to send vhost-user requests to a vhost-user slave over a
/// Unix domain socket.
///
/// Cheap to clone: every clone shares the same connection and negotiated
/// state behind an `Arc<Mutex<..>>`.
#[derive(Clone)]
pub struct Master {
    // Shared, lock-protected connection and protocol state.
    node: Arc<Mutex<MasterInternal>>,
}
impl Master {
    /// Create a new instance wrapping `ep`. The slave is assumed to support
    /// at most `max_queue_num` queues until GET_QUEUE_NUM says otherwise.
    fn new(ep: Endpoint<MasterReq>, max_queue_num: u64) -> Self {
        Master {
            node: Arc::new(Mutex::new(MasterInternal {
                main_sock: ep,
                virtio_features: 0,
                acked_virtio_features: 0,
                protocol_features: 0,
                acked_protocol_features: 0,
                protocol_features_ready: false,
                max_queue_num,
                error: None,
                hdr_flags: VhostUserHeaderFlag::empty(),
            })),
        }
    }

    /// Lock and return the shared internal state.
    ///
    /// Panics if the mutex is poisoned (a previous holder panicked).
    fn node(&self) -> MutexGuard<MasterInternal> {
        self.node.lock().unwrap()
    }

    /// Create a new vhost-user master from an already-connected Unix stream.
    pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
        Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num)
    }

    /// Connect to the vhost-user slave listening on `path`.
    ///
    /// Retries up to 5 times, 100ms apart, while the connection is refused,
    /// to tolerate a slave process that is still starting up.
    pub fn connect<P: AsRef<Path>>(path: P, max_queue_num: u64) -> Result<Self> {
        let mut retry_count = 5;
        let endpoint = loop {
            match Endpoint::<MasterReq>::connect(&path) {
                Ok(endpoint) => break Ok(endpoint),
                Err(e) => match &e {
                    VhostUserError::SocketConnect(why) => {
                        if why.kind() == std::io::ErrorKind::ConnectionRefused && retry_count > 0 {
                            std::thread::sleep(std::time::Duration::from_millis(100));
                            retry_count -= 1;
                            continue;
                        } else {
                            break Err(e);
                        }
                    }
                    // Any non-connect error is fatal immediately.
                    _ => break Err(e),
                },
            }
        }?;
        Ok(Self::new(endpoint, max_queue_num))
    }

    /// Set the header flags (e.g. NEED_REPLY) applied to all subsequent requests.
    pub fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) {
        let mut node = self.node();
        node.hdr_flags = flags;
    }
}
impl VhostBackend for Master {
    /// Get a bitmask of supported virtio/vhost features from the slave and
    /// cache it for later feature checks.
    fn get_features(&self) -> Result<u64> {
        let mut node = self.node();
        let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?;
        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
        node.virtio_features = val.value;
        Ok(node.virtio_features)
    }

    /// Inform the slave which features the master accepts. Only bits the
    /// slave actually offered are recorded as acknowledged.
    fn set_features(&self, features: u64) -> Result<()> {
        let mut node = self.node();
        let val = VhostUserU64::new(features);
        let hdr = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?;
        node.acked_virtio_features = features & node.virtio_features;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Mark this connection as the session owner.
    fn set_owner(&self) -> Result<()> {
        let mut node = self.node();
        let hdr = node.send_request_header(MasterReq::SET_OWNER, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Ask the slave to reset all session state owned by this master.
    fn reset_owner(&self) -> Result<()> {
        let mut node = self.node();
        let hdr = node.send_request_header(MasterReq::RESET_OWNER, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Send the full guest memory layout to the slave, attaching one mmap fd
    /// per region.
    fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
        if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES {
            return error_code(VhostUserError::InvalidParam);
        }

        let mut ctx = VhostUserMemoryContext::new();
        for region in regions.iter() {
            if region.memory_size == 0 || region.mmap_handle < 0 {
                return error_code(VhostUserError::InvalidParam);
            }
            let reg = VhostUserMemoryRegion {
                guest_phys_addr: region.guest_phys_addr,
                memory_size: region.memory_size,
                user_addr: region.userspace_addr,
                mmap_offset: region.mmap_offset,
            };
            // Was garbled as "®" (HTML-entity residue of `&reg`) before.
            ctx.append(&reg, region.mmap_handle);
        }

        let mut node = self.node();
        let body = VhostUserMemory::new(ctx.regions.len() as u32);
        // SAFETY: viewing the packed region descriptors as raw bytes for the
        // wire payload; the descriptor type is plain old data.
        let (_, payload, _) = unsafe { ctx.regions.align_to::<u8>() };
        let hdr = node.send_request_with_payload(
            MasterReq::SET_MEM_TABLE,
            &body,
            payload,
            Some(ctx.fds.as_slice()),
        )?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Set the address and metadata of the dirty-page log.
    ///
    /// When the LOG_SHMFD protocol feature is acknowledged and a shared
    /// memory region is supplied, the region's fd is attached and the
    /// slave's ack is awaited; otherwise only the base address is sent,
    /// without waiting for an ack.
    fn set_log_base(&self, base: u64, region: Option<VhostUserDirtyLogRegion>) -> Result<()> {
        let mut node = self.node();

        // Match guard replaces the old `is_some()` + `unwrap()` pair, so the
        // `#[allow(clippy::unnecessary_unwrap)]` is no longer needed.
        match region {
            Some(region)
                if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits()
                    != 0 =>
            {
                let log = VhostUserLog {
                    mmap_size: region.mmap_size,
                    mmap_offset: region.mmap_offset,
                };
                let hdr = node.send_request_with_body(
                    MasterReq::SET_LOG_BASE,
                    &log,
                    Some(&[region.mmap_handle]),
                )?;
                node.wait_for_ack(&hdr).map_err(|e| e.into())
            }
            _ => {
                let val = VhostUserU64::new(base);
                let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, None)?;
                Ok(())
            }
        }
    }

    /// Hand the slave an eventfd used to signal dirty-log updates.
    fn set_log_fd(&self, fd: RawFd) -> Result<()> {
        let mut node = self.node();
        let fds = [fd];
        let hdr = node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Set the size (number of descriptors) of queue `queue_index`.
    fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
        let mut node = self.node();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }

        let val = VhostUserVringState::new(queue_index as u32, num.into());
        let hdr = node.send_request_with_body(MasterReq::SET_VRING_NUM, &val, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Set the addresses of the descriptor/used/available rings for a queue.
    fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
        let mut node = self.node();
        // Reject unknown address flags as well as out-of-range queues.
        if queue_index as u64 >= node.max_queue_num
            || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0
        {
            return error_code(VhostUserError::InvalidParam);
        }

        let val = VhostUserVringAddr::from_config_data(queue_index as u32, config_data);
        let hdr = node.send_request_with_body(MasterReq::SET_VRING_ADDR, &val, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Set the next available-ring index the slave should process from.
    fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
        let mut node = self.node();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }

        let val = VhostUserVringState::new(queue_index as u32, base.into());
        let hdr = node.send_request_with_body(MasterReq::SET_VRING_BASE, &val, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Query the current available-ring index of queue `queue_index`.
    fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
        let mut node = self.node();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }

        let req = VhostUserVringState::new(queue_index as u32, 0);
        let hdr = node.send_request_with_body(MasterReq::GET_VRING_BASE, &req, None)?;
        let reply = node.recv_reply::<VhostUserVringState>(&hdr)?;
        Ok(reply.num)
    }

    /// Set the eventfd the slave signals to notify the guest of used buffers.
    fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
        let mut node = self.node();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }
        let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Set the eventfd the guest signals to kick the slave for new buffers.
    fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
        let mut node = self.node();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }
        let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Set the eventfd the slave signals on vring errors.
    fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
        let mut node = self.node();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }
        let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }
}
impl VhostUserMaster for Master {
    /// Fetch the protocol feature bitmask from the slave and cache it.
    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
        let mut node = self.node();
        // PROTOCOL_FEATURES must have been offered by the slave first.
        node.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
        let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?;
        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
        node.protocol_features = val.value;
        // Reject replies carrying protocol feature bits we do not know about.
        match VhostUserProtocolFeatures::from_bits(node.protocol_features) {
            Some(val) => Ok(val),
            None => error_code(VhostUserError::InvalidMessage),
        }
    }

    /// Acknowledge a set of protocol features to the slave.
    ///
    /// The acked bits are recorded before the (optional) ack round-trip so
    /// that REPLY_ACK takes effect for this very request.
    fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
        let mut node = self.node();
        node.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
        let val = VhostUserU64::new(features.bits());
        let hdr = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?;
        node.acked_protocol_features = features.bits();
        node.protocol_features_ready = true;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Query the number of queues the slave supports (requires MQ) and cache
    /// it as the new queue-index upper bound.
    fn get_queue_num(&mut self) -> Result<u64> {
        let mut node = self.node();
        node.check_proto_feature(VhostUserProtocolFeatures::MQ)?;
        let hdr = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?;
        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
        if val.value > VHOST_USER_MAX_VRINGS {
            return error_code(VhostUserError::InvalidMessage);
        }
        node.max_queue_num = val.value;
        Ok(node.max_queue_num)
    }

    /// Enable or disable ring `queue_index` on the slave.
    ///
    /// Note: this checks the *acked* virtio features (not just the offered
    /// ones), since SET_VRING_ENABLE is only defined once PROTOCOL_FEATURES
    /// has actually been negotiated.
    fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> {
        let mut node = self.node();
        if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
            return error_code(VhostUserError::InactiveFeature(
                VhostUserVirtioFeatures::PROTOCOL_FEATURES,
            ));
        } else if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }

        let flag = if enable { 1 } else { 0 };
        let val = VhostUserVringState::new(queue_index as u32, flag);
        let hdr = node.send_request_with_body(MasterReq::SET_VRING_ENABLE, &val, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Read `size` bytes of device configuration space starting at `offset`.
    ///
    /// The reply must carry no fds and must echo the requested offset and
    /// size exactly; a zero-sized reply signals a slave-side error.
    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        flags: VhostUserConfigFlags,
        buf: &[u8],
    ) -> Result<(VhostUserConfig, VhostUserConfigPayload)> {
        let body = VhostUserConfig::new(offset, size, flags);
        if !body.is_valid() {
            return error_code(VhostUserError::InvalidParam);
        }

        let mut node = self.node();
        node.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
        let hdr = node.send_request_with_payload(MasterReq::GET_CONFIG, &body, buf, None)?;
        let (body_reply, buf_reply, rfds) =
            node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?;
        if rfds.is_some() {
            return error_code(VhostUserError::InvalidMessage);
        } else if body_reply.size == 0 {
            return error_code(VhostUserError::SlaveInternalError);
        } else if body_reply.size != body.size
            || body_reply.size as usize != buf.len()
            || body_reply.offset != body.offset
        {
            return error_code(VhostUserError::InvalidMessage);
        }
        Ok((body_reply, buf_reply))
    }

    /// Write `buf` into device configuration space at `offset`.
    fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()> {
        if buf.len() > MAX_MSG_SIZE {
            return error_code(VhostUserError::InvalidParam);
        }
        let body = VhostUserConfig::new(offset, buf.len() as u32, flags);
        if !body.is_valid() {
            return error_code(VhostUserError::InvalidParam);
        }

        let mut node = self.node();
        node.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
        let hdr = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Hand the slave a socket fd for slave-initiated requests
    /// (requires SLAVE_REQ).
    fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()> {
        let mut node = self.node();
        node.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?;
        let fds = [fd.as_raw_fd()];
        let hdr = node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Fetch the in-flight tracking state; exactly one backing file must be
    /// attached to the reply.
    fn get_inflight_fd(
        &mut self,
        inflight: &VhostUserInflight,
    ) -> Result<(VhostUserInflight, File)> {
        let mut node = self.node();
        node.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
        let hdr = node.send_request_with_body(MasterReq::GET_INFLIGHT_FD, inflight, None)?;
        let (inflight, files) = node.recv_reply_with_files::<VhostUserInflight>(&hdr)?;

        match take_single_file(files) {
            Some(file) => Ok((inflight, file)),
            None => error_code(VhostUserError::IncorrectFds),
        }
    }

    /// Install previously-obtained in-flight tracking state into the slave.
    fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()> {
        let mut node = self.node();
        node.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
        // All fields of the descriptor (and the fd itself) must be usable.
        if inflight.mmap_size == 0 || inflight.num_queues == 0 || inflight.queue_size == 0 || fd < 0
        {
            return error_code(VhostUserError::InvalidParam);
        }
        let hdr = node.send_request_with_body(MasterReq::SET_INFLIGHT_FD, inflight, Some(&[fd]))?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Query the maximum number of memory slots (requires CONFIGURE_MEM_SLOTS).
    fn get_max_mem_slots(&mut self) -> Result<u64> {
        let mut node = self.node();
        node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
        let hdr = node.send_request_header(MasterReq::GET_MAX_MEM_SLOTS, None)?;
        let val = node.recv_reply::<VhostUserU64>(&hdr)?;

        Ok(val.value)
    }

    /// Register one guest memory region, attaching its mmap fd.
    fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
        let mut node = self.node();
        node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
        if region.memory_size == 0 || region.mmap_handle < 0 {
            return error_code(VhostUserError::InvalidParam);
        }

        let body = VhostUserSingleMemoryRegion::new(
            region.guest_phys_addr,
            region.memory_size,
            region.userspace_addr,
            region.mmap_offset,
        );
        let fds = [region.mmap_handle];
        let hdr = node.send_request_with_body(MasterReq::ADD_MEM_REG, &body, Some(&fds))?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }

    /// Unregister one guest memory region (no fd needed for removal).
    fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
        let mut node = self.node();
        node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
        if region.memory_size == 0 {
            return error_code(VhostUserError::InvalidParam);
        }

        let body = VhostUserSingleMemoryRegion::new(
            region.guest_phys_addr,
            region.memory_size,
            region.userspace_addr,
            region.mmap_offset,
        );
        let hdr = node.send_request_with_body(MasterReq::REM_MEM_REG, &body, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }
}
impl AsRawFd for Master {
    /// Expose the raw fd of the underlying master Unix socket.
    fn as_raw_fd(&self) -> RawFd {
        self.node().main_sock.as_raw_fd()
    }
}
/// Scratch container pairing wire-format memory region descriptors with the
/// fds backing their mappings, used to build SET_MEM_TABLE requests.
struct VhostUserMemoryContext {
    // Serialized region descriptors, in request order.
    regions: VhostUserMemoryPayload,
    // One mmap fd per region, in the same order as `regions`.
    fds: Vec<RawFd>,
}
impl VhostUserMemoryContext {
    /// Create an empty context with no regions and no fds.
    pub fn new() -> Self {
        Self {
            regions: VhostUserMemoryPayload::new(),
            fds: Vec::new(),
        }
    }

    /// Record one memory region together with the fd that backs its mapping.
    pub fn append(&mut self, region: &VhostUserMemoryRegion, fd: RawFd) {
        self.fds.push(fd);
        self.regions.push(*region);
    }
}
/// Mutable state shared by all clones of a `Master`, protected by its mutex.
struct MasterInternal {
    // Socket connection to the vhost-user slave.
    main_sock: Endpoint<MasterReq>,
    // Virtio feature bits offered by the slave (cached by get_features()).
    virtio_features: u64,
    // Subset of `virtio_features` acknowledged via set_features().
    acked_virtio_features: u64,
    // Protocol feature bits offered by the slave.
    protocol_features: u64,
    // Protocol feature bits acknowledged via set_protocol_features().
    acked_protocol_features: u64,
    // True once protocol feature negotiation has completed.
    protocol_features_ready: bool,
    // Upper bound (exclusive) for valid queue indexes.
    max_queue_num: u64,
    // Sticky raw OS error; when set, every operation fails with SocketBroken.
    error: Option<i32>,
    // Extra header flags applied to every outgoing request.
    hdr_flags: VhostUserHeaderFlag,
}
impl MasterInternal {
    /// Send a header-only request (no body), optionally attaching fds.
    /// Returns the header so the caller can match the reply against it.
    fn send_request_header(
        &mut self,
        code: MasterReq,
        fds: Option<&[RawFd]>,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        self.check_state()?;
        let hdr = self.new_request_header(code, 0);
        self.main_sock.send_header(&hdr, fds)?;
        Ok(hdr)
    }

    /// Send a request with a fixed-size body `msg`, optionally attaching fds.
    fn send_request_with_body<T: Sized>(
        &mut self,
        code: MasterReq,
        msg: &T,
        fds: Option<&[RawFd]>,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        // The body must fit inside a single vhost-user message.
        if mem::size_of::<T>() > MAX_MSG_SIZE {
            return Err(VhostUserError::InvalidParam);
        }
        self.check_state()?;

        let hdr = self.new_request_header(code, mem::size_of::<T>() as u32);
        self.main_sock.send_message(&hdr, msg, fds)?;
        Ok(hdr)
    }

    /// Send a request with a fixed-size body followed by a variable payload,
    /// optionally attaching fds. Body + payload must fit in one message and
    /// the fd list must not exceed the per-message fd limit.
    fn send_request_with_payload<T: Sized>(
        &mut self,
        code: MasterReq,
        msg: &T,
        payload: &[u8],
        fds: Option<&[RawFd]>,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        let len = mem::size_of::<T>() + payload.len();
        if len > MAX_MSG_SIZE {
            return Err(VhostUserError::InvalidParam);
        }
        if let Some(fd_arr) = fds {
            if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
                return Err(VhostUserError::InvalidParam);
            }
        }
        self.check_state()?;

        let hdr = self.new_request_header(code, len as u32);
        self.main_sock
            .send_message_with_payload(&hdr, msg, payload, fds)?;
        Ok(hdr)
    }

    /// Send a vring request whose body is the queue index, with one fd
    /// attached (used by SET_VRING_CALL/KICK/ERR).
    fn send_fd_for_vring(
        &mut self,
        code: MasterReq,
        queue_index: usize,
        fd: RawFd,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        if queue_index as u64 >= self.max_queue_num {
            return Err(VhostUserError::InvalidParam);
        }
        self.check_state()?;

        // NOTE(review): only the queue index is sent here; per the wire
        // format the upper bits of this u64 carry flags (e.g. "no fd"),
        // which are left clear since an fd is always attached.
        let msg = VhostUserU64::new(queue_index as u64);
        let hdr = self.new_request_header(code, mem::size_of::<VhostUserU64>() as u32);
        self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?;
        Ok(hdr)
    }

    /// Receive a fixed-size reply body for the request identified by `hdr`.
    /// The reply must match the request, carry no fds, and pass validation.
    fn recv_reply<T: ByteValued + Sized + VhostUserMsgValidator>(
        &mut self,
        hdr: &VhostUserMsgHeader<MasterReq>,
    ) -> VhostUserResult<T> {
        // `hdr` is the request header; a header already flagged as a reply
        // cannot identify an outstanding request.
        if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
            return Err(VhostUserError::InvalidParam);
        }
        self.check_state()?;

        let (reply, body, rfds) = self.main_sock.recv_body::<T>()?;
        if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
            return Err(VhostUserError::InvalidMessage);
        }
        Ok(body)
    }

    /// Receive a reply that MUST carry attached files (e.g. GET_INFLIGHT_FD).
    fn recv_reply_with_files<T: ByteValued + Sized + VhostUserMsgValidator>(
        &mut self,
        hdr: &VhostUserMsgHeader<MasterReq>,
    ) -> VhostUserResult<(T, Option<Vec<File>>)> {
        if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
            return Err(VhostUserError::InvalidParam);
        }
        self.check_state()?;

        let (reply, body, files) = self.main_sock.recv_body::<T>()?;
        // Unlike recv_reply(), the absence of files is an error here.
        if !reply.is_reply_for(hdr) || files.is_none() || !body.is_valid() {
            return Err(VhostUserError::InvalidMessage);
        }
        Ok((body, files))
    }

    /// Receive a reply consisting of a fixed-size body plus a variable
    /// payload (e.g. GET_CONFIG). The payload length is derived from the
    /// reply header; any attached files make the reply invalid.
    fn recv_reply_with_payload<T: ByteValued + Sized + VhostUserMsgValidator>(
        &mut self,
        hdr: &VhostUserMsgHeader<MasterReq>,
    ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<File>>)> {
        if mem::size_of::<T>() > MAX_MSG_SIZE
            || hdr.get_size() as usize <= mem::size_of::<T>()
            || hdr.get_size() as usize > MAX_MSG_SIZE
            || hdr.is_reply()
        {
            return Err(VhostUserError::InvalidParam);
        }
        self.check_state()?;

        // Buffer sized for the payload portion only (total minus body).
        let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()];
        let (reply, body, bytes, files) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?;
        if !reply.is_reply_for(hdr)
            || reply.get_size() as usize != mem::size_of::<T>() + bytes
            || files.is_some()
            || !body.is_valid()
            || bytes != buf.len()
        {
            return Err(VhostUserError::InvalidMessage);
        }
        Ok((body, buf, files))
    }

    /// Wait for the REPLY_ACK of a request, when applicable.
    ///
    /// A no-op unless the REPLY_ACK protocol feature was acknowledged AND
    /// the request header carried the need-reply flag. A non-zero ack value
    /// indicates a slave-side failure.
    fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> {
        if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() == 0
            || !hdr.is_need_reply()
        {
            return Ok(());
        }
        self.check_state()?;

        let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?;
        if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
            return Err(VhostUserError::InvalidMessage);
        }
        if body.value != 0 {
            return Err(VhostUserError::SlaveInternalError);
        }
        Ok(())
    }

    /// Succeed only if the slave has offered the given virtio feature.
    fn check_feature(&self, feat: VhostUserVirtioFeatures) -> VhostUserResult<()> {
        if self.virtio_features & feat.bits() != 0 {
            Ok(())
        } else {
            Err(VhostUserError::InactiveFeature(feat))
        }
    }

    /// Succeed only if the given protocol feature has been acknowledged.
    fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> VhostUserResult<()> {
        if self.acked_protocol_features & feat.bits() != 0 {
            Ok(())
        } else {
            Err(VhostUserError::InactiveOperation(feat))
        }
    }

    /// Fail fast if the connection has recorded a sticky error.
    fn check_state(&self) -> VhostUserResult<()> {
        match self.error {
            Some(e) => Err(VhostUserError::SocketBroken(
                std::io::Error::from_raw_os_error(e),
            )),
            None => Ok(()),
        }
    }

    /// Build a request header carrying the configured flags plus the
    /// protocol version bit (0x1; tests assert get_version() == 0x1).
    #[inline]
    fn new_request_header(&self, request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> {
        VhostUserMsgHeader::new(request, self.hdr_flags.bits() | 0x1, size)
    }
}
#[cfg(test)]
mod tests {
    use super::super::connection::Listener;
    use super::*;
    use std::path::PathBuf;
    use vmm_sys_util::rand::rand_alphanumerics;

    /// Build a unique socket path under /tmp for one test.
    fn temp_path() -> PathBuf {
        PathBuf::from(format!(
            "/tmp/vhost_test_{}",
            rand_alphanumerics(8).to_str().unwrap()
        ))
    }

    /// Create a connected (master, slave-endpoint) pair over `path`,
    /// with the master allowing 2 queues.
    fn create_pair<P: AsRef<Path>>(path: P) -> (Master, Endpoint<MasterReq>) {
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let master = Master::connect(path, 2).unwrap();
        let slave = listener.accept().unwrap().unwrap();
        (master, Endpoint::from_stream(slave))
    }

    #[test]
    fn create_master() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let master = Master::connect(&path, 1).unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap());

        assert!(master.as_raw_fd() > 0);
        // Send two header-only requests from the master...
        master.set_owner().unwrap();
        master.reset_owner().unwrap();

        // ...and verify the slave receives both, in order, with version 0x1.
        let (hdr, rfds) = slave.recv_header().unwrap();
        assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
        assert_eq!(hdr.get_size(), 0);
        assert_eq!(hdr.get_version(), 0x1);
        assert!(rfds.is_none());

        let (hdr, rfds) = slave.recv_header().unwrap();
        assert_eq!(hdr.get_code(), MasterReq::RESET_OWNER);
        assert_eq!(hdr.get_size(), 0);
        assert_eq!(hdr.get_version(), 0x1);
        assert!(rfds.is_none());
    }

    #[test]
    fn test_create_failure() {
        let path = temp_path();
        let _ = Listener::new(&path, true).unwrap();
        // NOTE(review): the result of this re-bind attempt is discarded, so
        // it asserts nothing; whether it fails depends on Listener's Drop
        // semantics — confirm before turning it into an assert.
        let _ = Listener::new(&path, false).is_err();
        // No live listener is accepting, so connect() must eventually fail.
        assert!(Master::connect(&path, 1).is_err());

        let listener = Listener::new(&path, true).unwrap();
        // Re-binding without unlinking while the path is held must fail.
        assert!(Listener::new(&path, false).is_err());
        listener.set_nonblocking(true).unwrap();

        let _master = Master::connect(&path, 1).unwrap();
        let _slave = listener.accept().unwrap().unwrap();
    }

    #[test]
    fn test_features() {
        let path = temp_path();
        let (master, mut peer) = create_pair(&path);

        master.set_owner().unwrap();
        let (hdr, rfds) = peer.recv_header().unwrap();
        assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
        assert_eq!(hdr.get_size(), 0);
        assert_eq!(hdr.get_version(), 0x1);
        assert!(rfds.is_none());

        // Pre-queue a GET_FEATURES reply, then let the master read it.
        let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(0x15);
        peer.send_message(&hdr, &msg, None).unwrap();
        let features = master.get_features().unwrap();
        assert_eq!(features, 0x15u64);
        let (_hdr, rfds) = peer.recv_header().unwrap();
        assert!(rfds.is_none());

        // SET_FEATURES round-trip: the peer should see the same bits back.
        let hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(0x15);
        peer.send_message(&hdr, &msg, None).unwrap();
        master.set_features(0x15).unwrap();
        let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
        assert!(rfds.is_none());
        let val = msg.value;
        assert_eq!(val, 0x15);

        // A reply body with the wrong size (u32 instead of u64) is rejected.
        let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
        let msg = 0x15u32;
        peer.send_message(&hdr, &msg, None).unwrap();
        assert!(master.get_features().is_err());
    }

    #[test]
    fn test_protocol_features() {
        let path = temp_path();
        let (mut master, mut peer) = create_pair(&path);

        master.set_owner().unwrap();
        let (hdr, rfds) = peer.recv_header().unwrap();
        assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
        assert!(rfds.is_none());

        // Protocol features cannot be used before PROTOCOL_FEATURES is offered.
        assert!(master.get_protocol_features().is_err());
        assert!(master
            .set_protocol_features(VhostUserProtocolFeatures::all())
            .is_err());

        let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(vfeatures);
        peer.send_message(&hdr, &msg, None).unwrap();
        let features = master.get_features().unwrap();
        assert_eq!(features, vfeatures);
        let (_hdr, rfds) = peer.recv_header().unwrap();
        assert!(rfds.is_none());

        master.set_features(vfeatures).unwrap();
        let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
        assert!(rfds.is_none());
        let val = msg.value;
        assert_eq!(val, vfeatures);

        // Now protocol feature negotiation should round-trip.
        let pfeatures = VhostUserProtocolFeatures::all();
        let hdr = VhostUserMsgHeader::new(MasterReq::GET_PROTOCOL_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(pfeatures.bits());
        peer.send_message(&hdr, &msg, None).unwrap();
        let features = master.get_protocol_features().unwrap();
        assert_eq!(features, pfeatures);
        let (_hdr, rfds) = peer.recv_header().unwrap();
        assert!(rfds.is_none());

        master.set_protocol_features(pfeatures).unwrap();
        let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
        assert!(rfds.is_none());
        let val = msg.value;
        assert_eq!(val, pfeatures.bits());

        // A reply whose code does not match the request must be rejected.
        let hdr = VhostUserMsgHeader::new(MasterReq::SET_PROTOCOL_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(pfeatures.bits());
        peer.send_message(&hdr, &msg, None).unwrap();
        assert!(master.get_protocol_features().is_err());
    }

    #[test]
    fn test_master_set_config_negative() {
        let path = temp_path();
        let (mut master, _peer) = create_pair(&path);
        let buf = vec![0x0; MAX_MSG_SIZE + 1];

        // Fails before the CONFIG protocol feature is acknowledged.
        master
            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .unwrap_err();

        {
            // Force all features to acknowledged for the negative cases below.
            let mut node = master.node();
            node.virtio_features = 0xffff_ffff;
            node.acked_virtio_features = 0xffff_ffff;
            node.protocol_features = 0xffff_ffff;
            node.acked_protocol_features = 0xffff_ffff;
        }

        master
            .set_config(0, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .unwrap();
        // Offset at/after the end of config space is invalid.
        master
            .set_config(
                VHOST_USER_CONFIG_SIZE,
                VhostUserConfigFlags::WRITABLE,
                &buf[0..4],
            )
            .unwrap_err();
        master
            .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .unwrap_err();
        // Unknown flag bits are invalid.
        master
            .set_config(
                0x100,
                unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) },
                &buf[0..4],
            )
            .unwrap_err();
        // Oversized and empty payloads are invalid.
        master
            .set_config(VHOST_USER_CONFIG_SIZE, VhostUserConfigFlags::WRITABLE, &buf)
            .unwrap_err();
        master
            .set_config(VHOST_USER_CONFIG_SIZE, VhostUserConfigFlags::WRITABLE, &[])
            .unwrap_err();
    }

    /// Like create_pair(), but with all feature negotiation pre-completed.
    fn create_pair2() -> (Master, Endpoint<MasterReq>) {
        let path = temp_path();
        let (master, peer) = create_pair(&path);

        {
            let mut node = master.node();
            node.virtio_features = 0xffff_ffff;
            node.acked_virtio_features = 0xffff_ffff;
            node.protocol_features = 0xffff_ffff;
            node.acked_protocol_features = 0xffff_ffff;
        }

        (master, peer)
    }

    #[test]
    fn test_master_get_config_negative0() {
        let (mut master, mut peer) = create_pair2();
        let buf = vec![0x0; MAX_MSG_SIZE + 1];

        let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
        let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        // A reply carrying the wrong request code must be rejected.
        hdr.set_code(MasterReq::GET_FEATURES);
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
        // Removed dead trailing `hdr.set_code(MasterReq::GET_CONFIG)`; the
        // header was never used again after the assertion above.
    }

    #[test]
    fn test_master_get_config_negative1() {
        let (mut master, mut peer) = create_pair2();
        let buf = vec![0x0; MAX_MSG_SIZE + 1];

        let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
        let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        // A message without the reply flag set must be rejected.
        hdr.set_reply(false);
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    #[test]
    fn test_master_get_config_negative2() {
        let (mut master, mut peer) = create_pair2();
        let buf = vec![0x0; MAX_MSG_SIZE + 1];

        // NOTE(review): this case only exercises the success path; the
        // failing variant appears to have been dropped — consider restoring.
        let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
        let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());
    }

    #[test]
    fn test_master_get_config_negative3() {
        let (mut master, mut peer) = create_pair2();
        let buf = vec![0x0; MAX_MSG_SIZE + 1];

        let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        // A reply echoing a different offset (0) must be rejected.
        msg.offset = 0;
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    #[test]
    fn test_master_get_config_negative4() {
        let (mut master, mut peer) = create_pair2();
        let buf = vec![0x0; MAX_MSG_SIZE + 1];

        let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        // A reply echoing a slightly different offset must be rejected.
        msg.offset = 0x101;
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    #[test]
    fn test_master_get_config_negative5() {
        let (mut master, mut peer) = create_pair2();
        let buf = vec![0x0; MAX_MSG_SIZE + 1];

        let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        // A reply with an out-of-range offset must be rejected.
        msg.offset = (MAX_MSG_SIZE + 1) as u32;
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    #[test]
    fn test_master_get_config_negative6() {
        let (mut master, mut peer) = create_pair2();
        let buf = vec![0x0; MAX_MSG_SIZE + 1];

        let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        // A reply whose size differs from the requested size must be rejected.
        msg.size = 6;
        peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None)
            .unwrap();
        assert!(master
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    // Renamed from `test_maset_set_mem_table_failure` (typo: "maset").
    #[test]
    fn test_master_set_mem_table_failure() {
        let (master, _peer) = create_pair2();

        // Empty region lists and oversized region lists are both rejected.
        master.set_mem_table(&[]).unwrap_err();
        let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1];
        master.set_mem_table(&tables).unwrap_err();
    }
}