use std::fs::File;
use std::mem;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::slice;
use std::sync::{Arc, Mutex};
use vm_memory::ByteValued;
use super::backend_req::Backend;
use super::connection::Endpoint;
use super::gpu_backend_req::GpuBackend;
use super::message::*;
use super::{take_single_file, Error, Result};
/// Callbacks invoked by [`BackendReqHandler::handle_request`] for each frontend request.
///
/// All methods take `&self`; implementors needing mutation can implement
/// [`VhostUserBackendReqHandlerMut`] instead and wrap the value in a `Mutex`, which
/// receives this trait through a blanket implementation in this file.
#[allow(missing_docs)]
pub trait VhostUserBackendReqHandler {
/// Called for `FrontendReq::SET_OWNER`; the result is reported through the ack path.
fn set_owner(&self) -> Result<()>;
/// Called for `FrontendReq::RESET_OWNER`.
fn reset_owner(&self) -> Result<()>;
/// Called for `FrontendReq::RESET_DEVICE`, only after the `RESET_DEVICE` protocol
/// feature bit was acked.
fn reset_device(&self) -> Result<()>;
/// Called for `FrontendReq::GET_FEATURES`; the returned bits are sent back to the
/// frontend and remembered by the handler.
fn get_features(&self) -> Result<u64>;
/// Called for `FrontendReq::SET_FEATURES` with the value carried by the request.
fn set_features(&self, features: u64) -> Result<()>;
/// Called for `FrontendReq::SET_MEM_TABLE`. The handler has already checked that every
/// region in `ctx` is valid and that `files` holds exactly one entry per region.
fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
/// Called for `FrontendReq::SET_VRING_NUM` with the `index` and `num` fields of the request.
fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
/// Called for `FrontendReq::SET_VRING_ADDR`; `flags` has already been decoded from the
/// raw request bits (unknown bits make the handler reject the request).
fn set_vring_addr(
&self,
index: u32,
flags: VhostUserVringAddrFlags,
descriptor: u64,
used: u64,
available: u64,
log: u64,
) -> Result<()>;
/// Called for `FrontendReq::SET_VRING_BASE`; `base` is the request's `num` field.
fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
/// Called for `FrontendReq::GET_VRING_BASE`; the returned state is sent as the reply.
fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
/// Called for `FrontendReq::SET_VRING_KICK`. `index` is the low 8 bits of the request
/// payload; `fd` is `None` when bit 8 (`0x100`) of the payload was set.
fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>;
/// Called for `FrontendReq::SET_VRING_CALL`; same `index`/`fd` decoding as
/// `set_vring_kick`.
fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>;
/// Called for `FrontendReq::SET_VRING_ERR`; same `index`/`fd` decoding as
/// `set_vring_kick`.
fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>;
/// Called for `FrontendReq::GET_PROTOCOL_FEATURES`. The handler adds `REPLY_ACK`
/// (and `XEN_MMAP` when built with the `xen` feature) before replying.
fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
/// Called for `FrontendReq::SET_PROTOCOL_FEATURES` with the value carried by the request.
fn set_protocol_features(&self, features: u64) -> Result<()>;
/// Called for `FrontendReq::GET_QUEUE_NUM`, only after the `MQ` protocol feature was acked.
fn get_queue_num(&self) -> Result<u64>;
/// Called for `FrontendReq::SET_VRING_ENABLE`; the request's `num` field must be 0 or 1
/// and is passed as `enable`.
fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
/// Called for `FrontendReq::GET_CONFIG`. If the returned buffer is not exactly `size`
/// bytes long, or an error is returned, the handler replies with a zero-sized config.
fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
/// Called for `FrontendReq::SET_CONFIG`; `buf` is the payload following the
/// `VhostUserConfig` header of the request.
fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
/// Called for `FrontendReq::SET_BACKEND_REQ_FD` with a `Backend` built from the received
/// socket. The default implementation does nothing and drops `_backend`.
fn set_backend_req_fd(&self, _backend: Backend) {}
/// Called for `FrontendReq::GPU_SET_SOCKET` with a `GpuBackend` built from the received socket.
fn set_gpu_socket(&self, _gpu_backend: GpuBackend) -> Result<()>;
/// Called for `FrontendReq::GET_SHARED_OBJECT`. On success the file's descriptor is
/// attached to the reply; on error an empty reply without descriptor is sent.
fn get_shared_object(&self, uuid: VhostUserSharedMsg) -> Result<File>;
/// Called for `FrontendReq::GET_INFLIGHT_FD`; the returned message and file descriptor
/// form the reply.
fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>;
/// Called for `FrontendReq::SET_INFLIGHT_FD` with the single file attached to the request.
fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>;
/// Called for `FrontendReq::GET_MAX_MEM_SLOTS` (requires `CONFIGURE_MEM_SLOTS`).
fn get_max_mem_slots(&self) -> Result<u64>;
/// Called for `FrontendReq::ADD_MEM_REG` with the single file attached to the request.
fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
/// Called for `FrontendReq::REM_MEM_REG`.
fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
/// Called for `FrontendReq::SET_DEVICE_STATE_FD`. The handler replies with `0x100` for
/// `Ok(None)`, `0` plus the attached descriptor for `Ok(Some(file))`, and `0x101` on error.
fn set_device_state_fd(
&self,
direction: VhostTransferStateDirection,
phase: VhostTransferStatePhase,
fd: File,
) -> Result<Option<File>>;
/// Called for `FrontendReq::CHECK_DEVICE_STATE`; `Ok` is reported as 0, `Err` as 1.
fn check_device_state(&self) -> Result<()>;
/// Called for `FrontendReq::GET_SHMEM_CONFIG` (requires the `SHMEM` protocol feature);
/// the returned value is sent as the reply.
fn get_shmem_config(&self) -> Result<VhostUserShMemConfig>;
/// Called for `FrontendReq::POSTCOPY_ADVISE` (requires `PAGEFAULT`). On success the
/// file's descriptor is attached to the reply.
#[cfg(feature = "postcopy")]
fn postcopy_advice(&self) -> Result<File>;
/// Called for `FrontendReq::POSTCOPY_LISTEN` (requires `PAGEFAULT`).
#[cfg(feature = "postcopy")]
fn postcopy_listen(&self) -> Result<()>;
/// Called for `FrontendReq::POSTCOPY_END` (requires `PAGEFAULT`).
#[cfg(feature = "postcopy")]
fn postcopy_end(&self) -> Result<()>;
/// Called for `FrontendReq::SET_LOG_BASE` (requires `LOG_SHMFD`) with the single file
/// attached to the request; on success the handler echoes the request message back.
fn set_log_base(&self, log: &VhostUserLog, file: File) -> Result<()>;
}
/// `&mut self` flavour of [`VhostUserBackendReqHandler`].
///
/// The method set mirrors [`VhostUserBackendReqHandler`] one to one. A `Mutex<T>` around
/// an implementor `T` implements the `&self` trait by locking and forwarding (see the
/// blanket `impl` in this file).
#[allow(missing_docs)]
pub trait VhostUserBackendReqHandlerMut {
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_owner`].
fn set_owner(&mut self) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::reset_owner`].
fn reset_owner(&mut self) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::reset_device`].
fn reset_device(&mut self) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::get_features`].
fn get_features(&mut self) -> Result<u64>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_features`].
fn set_features(&mut self, features: u64) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_mem_table`].
fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_vring_num`].
fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_vring_addr`].
fn set_vring_addr(
&mut self,
index: u32,
flags: VhostUserVringAddrFlags,
descriptor: u64,
used: u64,
available: u64,
log: u64,
) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_vring_base`].
fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::get_vring_base`].
fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_vring_kick`].
fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_vring_call`].
fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_vring_err`].
fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::get_protocol_features`].
fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_protocol_features`].
fn set_protocol_features(&mut self, features: u64) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::get_queue_num`].
fn get_queue_num(&mut self) -> Result<u64>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_vring_enable`].
fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::get_config`].
fn get_config(
&mut self,
offset: u32,
size: u32,
flags: VhostUserConfigFlags,
) -> Result<Vec<u8>>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_config`].
fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_backend_req_fd`].
/// The default implementation does nothing and drops `_backend`.
fn set_backend_req_fd(&mut self, _backend: Backend) {}
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_gpu_socket`].
fn set_gpu_socket(&mut self, _gpu_backend: GpuBackend) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::get_shared_object`].
fn get_shared_object(&mut self, uuid: VhostUserSharedMsg) -> Result<File>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::get_inflight_fd`].
fn get_inflight_fd(
&mut self,
inflight: &VhostUserInflight,
) -> Result<(VhostUserInflight, File)>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_inflight_fd`].
fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::get_max_mem_slots`].
fn get_max_mem_slots(&mut self) -> Result<u64>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::add_mem_region`].
fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::remove_mem_region`].
fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_device_state_fd`].
fn set_device_state_fd(
&mut self,
direction: VhostTransferStateDirection,
phase: VhostTransferStatePhase,
fd: File,
) -> Result<Option<File>>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::check_device_state`].
fn check_device_state(&mut self) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::get_shmem_config`].
fn get_shmem_config(&mut self) -> Result<VhostUserShMemConfig>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::postcopy_advice`].
#[cfg(feature = "postcopy")]
fn postcopy_advice(&mut self) -> Result<File>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::postcopy_listen`].
#[cfg(feature = "postcopy")]
fn postcopy_listen(&mut self) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::postcopy_end`].
#[cfg(feature = "postcopy")]
fn postcopy_end(&mut self) -> Result<()>;
/// Mutable counterpart of [`VhostUserBackendReqHandler::set_log_base`].
fn set_log_base(&mut self, log: &VhostUserLog, file: File) -> Result<()>;
}
/// Adapter that lets a `Mutex`-wrapped [`VhostUserBackendReqHandlerMut`] be used wherever
/// a [`VhostUserBackendReqHandler`] is expected.
///
/// Every method acquires the lock, forwards the call with unchanged arguments and
/// returns the inner result. Acquiring the lock uses `unwrap()`, so a poisoned mutex
/// makes the call panic.
impl<T: VhostUserBackendReqHandlerMut> VhostUserBackendReqHandler for Mutex<T> {
    fn set_owner(&self) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_owner()
    }

    fn reset_owner(&self) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.reset_owner()
    }

    fn reset_device(&self) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.reset_device()
    }

    fn get_features(&self) -> Result<u64> {
        let mut inner = self.lock().unwrap();
        inner.get_features()
    }

    fn set_features(&self, features: u64) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_features(features)
    }

    fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_mem_table(ctx, files)
    }

    fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_vring_num(index, num)
    }

    fn set_vring_addr(
        &self,
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_vring_addr(index, flags, descriptor, used, available, log)
    }

    fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_vring_base(index, base)
    }

    fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
        let mut inner = self.lock().unwrap();
        inner.get_vring_base(index)
    }

    fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_vring_kick(index, fd)
    }

    fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_vring_call(index, fd)
    }

    fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_vring_err(index, fd)
    }

    fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
        let mut inner = self.lock().unwrap();
        inner.get_protocol_features()
    }

    fn set_protocol_features(&self, features: u64) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_protocol_features(features)
    }

    fn get_queue_num(&self) -> Result<u64> {
        let mut inner = self.lock().unwrap();
        inner.get_queue_num()
    }

    fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_vring_enable(index, enable)
    }

    fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
        let mut inner = self.lock().unwrap();
        inner.get_config(offset, size, flags)
    }

    fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_config(offset, buf, flags)
    }

    fn set_backend_req_fd(&self, backend: Backend) {
        let mut inner = self.lock().unwrap();
        inner.set_backend_req_fd(backend)
    }

    fn set_gpu_socket(&self, gpu_backend: GpuBackend) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_gpu_socket(gpu_backend)
    }

    fn get_shared_object(&self, uuid: VhostUserSharedMsg) -> Result<File> {
        let mut inner = self.lock().unwrap();
        inner.get_shared_object(uuid)
    }

    fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> {
        let mut inner = self.lock().unwrap();
        inner.get_inflight_fd(inflight)
    }

    fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_inflight_fd(inflight, file)
    }

    fn get_max_mem_slots(&self) -> Result<u64> {
        let mut inner = self.lock().unwrap();
        inner.get_max_mem_slots()
    }

    fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.add_mem_region(region, fd)
    }

    fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.remove_mem_region(region)
    }

    fn set_device_state_fd(
        &self,
        direction: VhostTransferStateDirection,
        phase: VhostTransferStatePhase,
        fd: File,
    ) -> Result<Option<File>> {
        let mut inner = self.lock().unwrap();
        inner.set_device_state_fd(direction, phase, fd)
    }

    fn check_device_state(&self) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.check_device_state()
    }

    fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
        let mut inner = self.lock().unwrap();
        inner.get_shmem_config()
    }

    #[cfg(feature = "postcopy")]
    fn postcopy_advice(&self) -> Result<File> {
        let mut inner = self.lock().unwrap();
        inner.postcopy_advice()
    }

    #[cfg(feature = "postcopy")]
    fn postcopy_listen(&self) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.postcopy_listen()
    }

    #[cfg(feature = "postcopy")]
    fn postcopy_end(&self) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.postcopy_end()
    }

    fn set_log_base(&self, log: &VhostUserLog, file: File) -> Result<()> {
        let mut inner = self.lock().unwrap();
        inner.set_log_base(log, file)
    }
}
/// Receives vhost-user frontend requests on a socket endpoint and dispatches them to a
/// shared backend implementing [`VhostUserBackendReqHandler`].
pub struct BackendReqHandler<S: VhostUserBackendReqHandler> {
/// Endpoint used to receive requests and send replies.
main_sock: Endpoint<VhostUserMsgHeader<FrontendReq>>,
/// Backend that receives the decoded requests.
backend: Arc<S>,
/// Feature bits last returned by the backend for a `GET_FEATURES` request.
virtio_features: u64,
/// Feature bits carried by the last `SET_FEATURES` request.
acked_virtio_features: u64,
/// Protocol feature bits carried by the last `SET_PROTOCOL_FEATURES` request.
acked_protocol_features: u64,
/// Whether ack replies are sent for requests that ask for one; recomputed by
/// `update_reply_ack_flag`.
reply_ack_enabled: bool,
/// OS error number recorded by `set_failed`; while set, `check_state` fails.
error: Option<i32>,
}
impl<S: VhostUserBackendReqHandler> BackendReqHandler<S> {
pub(super) fn new(
main_sock: Endpoint<VhostUserMsgHeader<FrontendReq>>,
backend: Arc<S>,
) -> Self {
BackendReqHandler {
main_sock,
backend,
virtio_features: 0,
acked_virtio_features: 0,
acked_protocol_features: 0,
reply_ack_enabled: false,
error: None,
}
}
fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> {
if self.acked_virtio_features & feat.bits() != 0 {
Ok(())
} else {
Err(Error::InactiveFeature(feat))
}
}
fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> {
if self.acked_protocol_features & feat.bits() != 0 {
Ok(())
} else {
Err(Error::InactiveOperation(feat))
}
}
/// Create a handler that talks to the frontend over an already connected `socket`.
pub fn from_stream(socket: UnixStream, backend: Arc<S>) -> Self {
    let endpoint = Endpoint::from_stream(socket);
    Self::new(endpoint, backend)
}
/// Connect an endpoint to `path` and create a handler on top of it.
///
/// # Errors
/// Propagates the error returned by `Endpoint::connect`.
pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> {
    let endpoint = Endpoint::<VhostUserMsgHeader<FrontendReq>>::connect(path)?;
    Ok(Self::new(endpoint, backend))
}
/// Record `error` (an OS error number) as the handler's failure state.
///
/// Once set, `check_state` fails, so `handle_request` and reply-header construction
/// return an error. A previously recorded value is overwritten.
pub fn set_failed(&mut self, error: i32) {
    let _previous = self.error.replace(error);
}
/// Receive one request from the frontend and dispatch it to the backend.
///
/// Processing order: fail fast if a failure was recorded with `set_failed`, receive the
/// header together with any attached files, check whether attached files are acceptable
/// for that request code, receive the body announced by the header, then dispatch on
/// the request code. Depending on the request, a reply or an ack is sent on the same
/// socket.
///
/// # Errors
/// Returns the recorded socket error, any error from receiving or sending on the
/// endpoint, `Error::InvalidMessage` for malformed requests and for request codes
/// without a dispatch arm, the feature-check errors when a required feature was not
/// acked, and the backend's own error for requests whose result is propagated.
pub fn handle_request(&mut self) -> Result<()> {
self.check_state()?;
let (hdr, files) = self.main_sock.recv_header()?;
self.check_attached_files(&hdr, &files)?;
// Receive the body only when the header announces one; a short read is rejected.
let (size, buf) = match hdr.get_size() {
0 => (0, vec![0u8; 0]),
len => {
let (size2, rbuf) = self.main_sock.recv_data(len as usize)?;
if size2 != len as usize {
return Err(Error::InvalidMessage);
}
(size2, rbuf)
}
};
match hdr.get_code() {
Ok(FrontendReq::SET_OWNER) => {
self.check_request_size(&hdr, size, 0)?;
let res = self.backend.set_owner();
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::RESET_OWNER) => {
self.check_request_size(&hdr, size, 0)?;
let res = self.backend.reset_owner();
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::RESET_DEVICE) => {
self.check_proto_feature(VhostUserProtocolFeatures::RESET_DEVICE)?;
self.check_request_size(&hdr, size, 0)?;
let res = self.backend.reset_device();
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::GET_FEATURES) => {
self.check_request_size(&hdr, size, 0)?;
let features = self.backend.get_features()?;
let msg = VhostUserU64::new(features);
self.send_reply_message(&hdr, &msg)?;
// Remember what was offered; `update_reply_ack_flag` reads it.
self.virtio_features = features;
self.update_reply_ack_flag();
}
Ok(FrontendReq::SET_FEATURES) => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
let res = self.backend.set_features(msg.value);
// The acked value is recorded even when the backend returned an error; that error
// is then propagated by `send_ack_message`.
self.acked_virtio_features = msg.value;
self.update_reply_ack_flag();
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::SET_MEM_TABLE) => {
let res = self.set_mem_table(&hdr, size, &buf, files);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::SET_VRING_NUM) => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
let res = self.backend.set_vring_num(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::SET_VRING_ADDR) => {
let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
// Unknown flag bits make the whole request invalid.
let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
Some(val) => val,
None => return Err(Error::InvalidMessage),
};
let res = self.backend.set_vring_addr(
msg.index,
flags,
msg.descriptor,
msg.used,
msg.available,
msg.log,
);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::SET_VRING_BASE) => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
let res = self.backend.set_vring_base(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::GET_VRING_BASE) => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
let reply = self.backend.get_vring_base(msg.index)?;
self.send_reply_message(&hdr, &reply)?;
}
Ok(FrontendReq::SET_VRING_CALL) => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, file) = self.handle_vring_fd_request(&buf, files)?;
let res = self.backend.set_vring_call(index, file);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::SET_VRING_KICK) => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, file) = self.handle_vring_fd_request(&buf, files)?;
let res = self.backend.set_vring_kick(index, file);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::SET_VRING_ERR) => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, file) = self.handle_vring_fd_request(&buf, files)?;
let res = self.backend.set_vring_err(index, file);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::GET_PROTOCOL_FEATURES) => {
self.check_request_size(&hdr, size, 0)?;
// REPLY_ACK is always added to whatever the backend offers.
let features =
self.backend.get_protocol_features()? | VhostUserProtocolFeatures::REPLY_ACK;
// With the `xen` feature, XEN_MMAP is added as well.
#[cfg(feature = "xen")]
let features = features | VhostUserProtocolFeatures::XEN_MMAP;
let msg = VhostUserU64::new(features.bits());
self.send_reply_message(&hdr, &msg)?;
self.update_reply_ack_flag();
}
Ok(FrontendReq::SET_PROTOCOL_FEATURES) => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
let res = self.backend.set_protocol_features(msg.value);
// As with SET_FEATURES, the acked value is recorded regardless of `res`.
self.acked_protocol_features = msg.value;
self.update_reply_ack_flag();
self.send_ack_message(&hdr, res)?;
// With the `xen` feature, the frontend must have acked XEN_MMAP; this is checked
// after the ack was sent.
#[cfg(feature = "xen")]
self.check_proto_feature(VhostUserProtocolFeatures::XEN_MMAP)?;
}
Ok(FrontendReq::GET_QUEUE_NUM) => {
self.check_proto_feature(VhostUserProtocolFeatures::MQ)?;
self.check_request_size(&hdr, size, 0)?;
let num = self.backend.get_queue_num()?;
let msg = VhostUserU64::new(num);
self.send_reply_message(&hdr, &msg)?;
}
Ok(FrontendReq::SET_VRING_ENABLE) => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
// Only 0 and 1 are accepted as the enable value.
let enable = match msg.num {
1 => true,
0 => false,
_ => return Err(Error::InvalidParam),
};
let res = self.backend.set_vring_enable(msg.index, enable);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::GET_CONFIG) => {
self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
// Variable-length request: the announced size is passed as the expected size, so
// this only checks the received byte count, the reply flag and the version.
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
self.get_config(&hdr, &buf)?;
}
Ok(FrontendReq::SET_CONFIG) => {
self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
let res = self.set_config(size, &buf);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::SET_BACKEND_REQ_FD) => {
self.check_proto_feature(VhostUserProtocolFeatures::BACKEND_REQ)?;
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
let res = self.set_backend_req_fd(files);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::GET_SHARED_OBJECT) => {
self.check_proto_feature(VhostUserProtocolFeatures::SHARED_OBJECT)?;
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
let msg = self.extract_request_body::<VhostUserSharedMsg>(&hdr, size, &buf)?;
let reply_hdr = self.new_reply_header::<VhostUserEmpty>(&hdr, 0)?;
// A reply is always sent; the descriptor is attached only on success.
match self.backend.get_shared_object(msg) {
Ok(file) => {
self.main_sock.send_message(
&reply_hdr,
&VhostUserEmpty,
Some(&[file.as_raw_fd()]),
)?;
}
Err(_) => {
self.main_sock
.send_message(&reply_hdr, &VhostUserEmpty, None)?;
}
}
}
Ok(FrontendReq::GET_INFLIGHT_FD) => {
self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
self.main_sock
.send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?;
}
Ok(FrontendReq::SET_INFLIGHT_FD) => {
self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
let res = self.backend.set_inflight_fd(&msg, file);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::GPU_SET_SOCKET) => {
let res = self.set_gpu_socket(files);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::GET_MAX_MEM_SLOTS) => {
self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
self.check_request_size(&hdr, size, 0)?;
let num = self.backend.get_max_mem_slots()?;
let msg = VhostUserU64::new(num);
self.send_reply_message(&hdr, &msg)?;
}
Ok(FrontendReq::ADD_MEM_REG) => {
self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
// Exactly one attached file is required.
let mut files = files.ok_or(Error::InvalidParam)?;
if files.len() != 1 {
return Err(Error::InvalidParam);
}
let msg =
self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
let res = self.backend.add_mem_region(&msg, files.swap_remove(0));
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::REM_MEM_REG) => {
self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
let msg =
self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
let res = self.backend.remove_mem_region(&msg);
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::SET_DEVICE_STATE_FD) => {
let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
let msg =
self.extract_request_body::<VhostUserTransferDeviceState>(&hdr, size, &buf)?;
let reply_hdr = self.new_reply_header::<VhostUserU64>(&hdr, 0)?;
let direction: VhostTransferStateDirection = msg
.direction
.try_into()
.map_err(|_| Error::InvalidMessage)?;
let phase: VhostTransferStatePhase =
msg.phase.try_into().map_err(|_| Error::InvalidMessage)?;
let res = self.backend.set_device_state_fd(direction, phase, file);
// Reply value: 0x100 for success without a descriptor, 0 for success with the
// descriptor attached, 0x101 when the backend failed.
match res {
Ok(None) => {
let msg = VhostUserU64::new(0x100); self.main_sock.send_message(&reply_hdr, &msg, None)?;
}
Ok(Some(file)) => {
let msg = VhostUserU64::new(0);
self.main_sock
.send_message(&reply_hdr, &msg, Some(&[file.as_raw_fd()]))?;
}
Err(_) => {
let msg = VhostUserU64::new(0x101);
self.main_sock.send_message(&reply_hdr, &msg, None)?;
}
}
}
Ok(FrontendReq::CHECK_DEVICE_STATE) => {
let res = self.backend.check_device_state();
// The backend result is reported in the reply (0 = ok, 1 = error), not propagated.
let msg = match res {
Ok(_) => VhostUserU64::new(0),
Err(_) => VhostUserU64::new(1),
};
self.send_reply_message(&hdr, &msg)?;
}
Ok(FrontendReq::GET_SHMEM_CONFIG) => {
self.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?;
let msg = self.backend.get_shmem_config()?;
self.send_reply_message(&hdr, &msg)?;
}
#[cfg(feature = "postcopy")]
Ok(FrontendReq::POSTCOPY_ADVISE) => {
self.check_proto_feature(VhostUserProtocolFeatures::PAGEFAULT)?;
let reply_hdr = self.new_reply_header::<VhostUserEmpty>(&hdr, 0)?;
let res = self.backend.postcopy_advice();
// A reply is always sent; the descriptor is attached only on success.
match res {
Ok(uffd_file) => self.main_sock.send_message(
&reply_hdr,
&VhostUserEmpty,
Some(&[uffd_file.as_raw_fd()]),
)?,
Err(_) => self
.main_sock
.send_message(&reply_hdr, &VhostUserEmpty, None)?,
}
}
#[cfg(feature = "postcopy")]
Ok(FrontendReq::POSTCOPY_LISTEN) => {
self.check_proto_feature(VhostUserProtocolFeatures::PAGEFAULT)?;
let res = self.backend.postcopy_listen();
self.send_ack_message(&hdr, res)?;
}
#[cfg(feature = "postcopy")]
Ok(FrontendReq::POSTCOPY_END) => {
self.check_proto_feature(VhostUserProtocolFeatures::PAGEFAULT)?;
let res = self.backend.postcopy_end();
self.send_ack_message(&hdr, res)?;
}
Ok(FrontendReq::SET_LOG_BASE) => {
self.check_proto_feature(VhostUserProtocolFeatures::LOG_SHMFD)?;
let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
let msg = self.extract_request_body::<VhostUserLog>(&hdr, size, &buf)?;
self.backend.set_log_base(&msg, file)?;
// On success the request message itself is echoed back as the reply.
self.send_reply_message(&hdr, &msg)?;
}
// Undecodable codes and codes without an arm above (for example `SET_LOG_FD`).
_ => {
return Err(Error::InvalidMessage);
}
}
Ok(())
}
/// Validate a `SET_MEM_TABLE` request and hand its regions and files to the backend.
///
/// `buf` must consist of a `VhostUserMemory` header followed by exactly `num_regions`
/// `VhostUserMemoryRegion` entries, and `files` must hold one file per region.
///
/// # Errors
/// `Error::InvalidMessage` for any size, validity or file-count mismatch; otherwise the
/// backend's result.
fn set_mem_table(
    &mut self,
    hdr: &VhostUserMsgHeader<FrontendReq>,
    size: usize,
    buf: &[u8],
    files: Option<Vec<File>>,
) -> Result<()> {
    self.check_request_size(hdr, size, hdr.get_size() as usize)?;

    let table_hdr_len = mem::size_of::<VhostUserMemory>();
    if size < table_hdr_len {
        return Err(Error::InvalidMessage);
    }
    // SAFETY: `size >= table_hdr_len` was checked above. NOTE(review): this relies on `buf`
    // holding at least `size` bytes and on `VhostUserMemory` having no alignment requirement
    // beyond what `buf` provides; confirm against the caller and the message definitions.
    let table = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) };
    if !table.is_valid() {
        return Err(Error::InvalidMessage);
    }

    let region_count = table.num_regions as usize;
    let expected_len = table_hdr_len + region_count * mem::size_of::<VhostUserMemoryRegion>();
    if size != expected_len {
        return Err(Error::InvalidMessage);
    }

    // One attached file per region is mandatory.
    let files = files.ok_or(Error::InvalidMessage)?;
    if files.len() != region_count {
        return Err(Error::InvalidMessage);
    }

    // SAFETY: `size` equals header plus `region_count` entries (checked above). The same
    // assumptions about `buf` length and alignment as for the header apply (TODO confirm).
    let regions = unsafe {
        slice::from_raw_parts(
            buf.as_ptr().add(table_hdr_len) as *const VhostUserMemoryRegion,
            region_count,
        )
    };
    if regions.iter().any(|region| !region.is_valid()) {
        return Err(Error::InvalidMessage);
    }

    self.backend.set_mem_table(regions, files)
}
fn get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()> {
let payload_offset = mem::size_of::<VhostUserConfig>();
if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset {
return Err(Error::InvalidMessage);
}
let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
if buf.len() - payload_offset != msg.size as usize {
return Err(Error::InvalidMessage);
}
let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
Some(val) => val,
None => return Err(Error::InvalidMessage),
};
let res = self.backend.get_config(msg.offset, msg.size, flags);
match res {
Ok(ref buf) if buf.len() == msg.size as usize => {
let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
}
Ok(_) => {
let reply = VhostUserConfig::new(msg.offset, 0, flags);
self.send_reply_message(hdr, &reply)?;
}
Err(_) => {
let reply = VhostUserConfig::new(msg.offset, 0, flags);
self.send_reply_message(hdr, &reply)?;
}
}
Ok(())
}
fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> {
if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() {
return Err(Error::InvalidMessage);
}
let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
if size - mem::size_of::<VhostUserConfig>() != msg.size as usize {
return Err(Error::InvalidMessage);
}
let flags = VhostUserConfigFlags::from_bits(msg.flags).ok_or(Error::InvalidMessage)?;
self.backend
.set_config(msg.offset, &buf[mem::size_of::<VhostUserConfig>()..], flags)
}
fn set_backend_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> {
let file = take_single_file(files).ok_or(Error::InvalidMessage)?;
let sock = unsafe { UnixStream::from_raw_fd(file.into_raw_fd()) };
let backend = Backend::from_stream(sock);
self.backend.set_backend_req_fd(backend);
Ok(())
}
fn set_gpu_socket(&mut self, files: Option<Vec<File>>) -> Result<()> {
let file = take_single_file(files).ok_or(Error::InvalidMessage)?;
let sock = unsafe { UnixStream::from_raw_fd(file.into_raw_fd()) };
let gpu_backend = GpuBackend::from_stream(sock);
self.backend.set_gpu_socket(gpu_backend)
}
/// Decode the payload shared by the `SET_VRING_CALL`/`KICK`/`ERR` requests.
///
/// The payload is a `VhostUserU64`: its low 8 bits are returned as the vring index, and
/// bit 8 (`0x100`) being clear means a file must be attached, while bit 8 being set
/// means no file may be attached.
///
/// # Errors
/// `Error::InvalidMessage` if the buffer size is out of range, the message is invalid,
/// or the presence of a file contradicts bit 8.
fn handle_vring_fd_request(
    &mut self,
    buf: &[u8],
    files: Option<Vec<File>>,
) -> Result<(u8, Option<File>)> {
    let len = buf.len();
    if len < mem::size_of::<VhostUserU64>() || len > MAX_MSG_SIZE {
        return Err(Error::InvalidMessage);
    }
    // SAFETY: `buf` holds at least `size_of::<VhostUserU64>()` bytes (checked above) and
    // `read_unaligned` has no alignment requirement.
    let payload = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) };
    if !payload.is_valid() {
        return Err(Error::InvalidMessage);
    }

    let expects_fd = (payload.value & 0x100u64) == 0;
    let file = take_single_file(files);
    // The flag and the attached file must agree.
    if expects_fd != file.is_some() {
        return Err(Error::InvalidMessage);
    }

    Ok((payload.value as u8, file))
}
fn check_state(&self) -> Result<()> {
match self.error {
Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
None => Ok(()),
}
}
fn check_request_size(
&self,
hdr: &VhostUserMsgHeader<FrontendReq>,
size: usize,
expected: usize,
) -> Result<()> {
if hdr.get_size() as usize != expected
|| hdr.is_reply()
|| hdr.get_version() != 0x1
|| size != expected
{
return Err(Error::InvalidMessage);
}
Ok(())
}
fn check_attached_files(
&self,
hdr: &VhostUserMsgHeader<FrontendReq>,
files: &Option<Vec<File>>,
) -> Result<()> {
match hdr.get_code() {
Ok(
FrontendReq::SET_MEM_TABLE
| FrontendReq::SET_VRING_CALL
| FrontendReq::SET_VRING_KICK
| FrontendReq::SET_VRING_ERR
| FrontendReq::SET_LOG_BASE
| FrontendReq::SET_LOG_FD
| FrontendReq::SET_BACKEND_REQ_FD
| FrontendReq::SET_INFLIGHT_FD
| FrontendReq::ADD_MEM_REG
| FrontendReq::SET_DEVICE_STATE_FD
| FrontendReq::GPU_SET_SOCKET,
) => Ok(()),
_ if files.is_some() => Err(Error::InvalidMessage),
_ => Ok(()),
}
}
fn extract_request_body<T: Sized + VhostUserMsgValidator>(
&self,
hdr: &VhostUserMsgHeader<FrontendReq>,
size: usize,
buf: &[u8],
) -> Result<T> {
self.check_request_size(hdr, size, mem::size_of::<T>())?;
let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
Ok(msg)
}
/// Recompute `reply_ack_enabled` from the current feature state.
///
/// The flag is set when the offered virtio features contain `PROTOCOL_FEATURES` and the
/// acked protocol features contain `REPLY_ACK`.
fn update_reply_ack_flag(&mut self) {
    let proto_features_bit = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
    let reply_ack_bit = VhostUserProtocolFeatures::REPLY_ACK.bits();

    let proto_features_offered = (self.virtio_features & proto_features_bit) != 0;
    let reply_ack_acked = (self.acked_protocol_features & reply_ack_bit) != 0;
    self.reply_ack_enabled = proto_features_offered && reply_ack_acked;
}
/// Build the header for a reply to `req` whose body is a `T` followed by `payload_size`
/// extra bytes.
///
/// The header reuses the request code, carries the `REPLY` flag and announces
/// `size_of::<T>() + payload_size` bytes.
///
/// # Errors
/// `Error::InvalidParam` if the body, the payload or their sum exceeds `MAX_MSG_SIZE`;
/// the recorded socket error if the handler was marked failed; the error from
/// `req.get_code()` if the request code cannot be decoded.
fn new_reply_header<T: Sized>(
    &self,
    req: &VhostUserMsgHeader<FrontendReq>,
    payload_size: usize,
) -> Result<VhostUserMsgHeader<FrontendReq>> {
    let body_len = mem::size_of::<T>();
    // The individual checks come first so the sum below cannot overflow.
    if body_len > MAX_MSG_SIZE
        || payload_size > MAX_MSG_SIZE
        || body_len + payload_size > MAX_MSG_SIZE
    {
        return Err(Error::InvalidParam);
    }
    self.check_state()?;

    let code = req.get_code()?;
    let flags = VhostUserHeaderFlag::REPLY.bits();
    let announced = (body_len + payload_size) as u32;
    Ok(VhostUserMsgHeader::new(code, flags, announced))
}
/// Send an ack for `req` when one is due, then hand `res` back to the caller.
///
/// An ack is sent only if reply-ack is enabled and the request asks for a reply. Its
/// value is 0 when `res` is `Ok` and 1 when it is `Err`.
///
/// # Errors
/// Errors from building or sending the ack take precedence; otherwise `res` itself is
/// returned, so a backend error is still propagated after being acked.
fn send_ack_message(
    &mut self,
    req: &VhostUserMsgHeader<FrontendReq>,
    res: Result<()>,
) -> Result<()> {
    let ack_due = self.reply_ack_enabled && req.is_need_reply();
    if ack_due {
        let reply_hdr = self.new_reply_header::<VhostUserU64>(req, 0)?;
        let status = if res.is_ok() { 0 } else { 1 };
        let ack = VhostUserU64::new(status);
        self.main_sock.send_message(&reply_hdr, &ack, None)?;
    }
    res
}
/// Send `msg` as the reply to `req`, without payload or attached files.
///
/// # Errors
/// Errors from `new_reply_header` or from sending on the endpoint.
fn send_reply_message<T: ByteValued>(
    &mut self,
    req: &VhostUserMsgHeader<FrontendReq>,
    msg: &T,
) -> Result<()> {
    let reply_hdr = self.new_reply_header::<T>(req, 0)?;
    self.main_sock.send_message(&reply_hdr, msg, None)?;
    Ok(())
}
/// Send `msg` followed by `payload` as the reply to `req`, without attached files.
///
/// The reply header announces `size_of::<T>() + payload.len()` bytes.
///
/// # Errors
/// Errors from `new_reply_header` or from sending on the endpoint.
fn send_reply_with_payload<T: ByteValued>(
    &mut self,
    req: &VhostUserMsgHeader<FrontendReq>,
    msg: &T,
    payload: &[u8],
) -> Result<()> {
    let payload_len = payload.len();
    let reply_hdr = self.new_reply_header::<T>(req, payload_len)?;
    self.main_sock
        .send_message_with_payload(&reply_hdr, msg, payload, None)?;
    Ok(())
}
}
/// Exposes the raw descriptor of the handler's main socket endpoint.
impl<S: VhostUserBackendReqHandler> AsRawFd for BackendReqHandler<S> {
    fn as_raw_fd(&self) -> RawFd {
        let Self { main_sock, .. } = self;
        main_sock.as_raw_fd()
    }
}
#[cfg(test)]
mod tests {
    use std::os::unix::io::AsRawFd;

    use super::*;
    use crate::vhost_user::dummy_backend::DummyBackendReqHandler;

    /// Endpoint type used on both sides of the socket pair in these tests.
    type FrontendEndpoint = Endpoint<VhostUserMsgHeader<FrontendReq>>;

    /// A fresh handler is healthy, becomes unhealthy after `set_failed`, and exposes a
    /// non-negative raw descriptor.
    #[test]
    fn test_backend_req_handler_new() {
        let (local, _remote) = UnixStream::pair().unwrap();
        let backend = Arc::new(Mutex::new(DummyBackendReqHandler::new()));
        let mut handler = BackendReqHandler::new(FrontendEndpoint::from_stream(local), backend);

        handler.check_state().unwrap();
        handler.set_failed(libc::EAGAIN);
        handler.check_state().unwrap_err();
        assert!(handler.as_raw_fd() >= 0);
    }

    /// Act as the frontend: send an empty `GET_SHMEM_CONFIG` request and return the
    /// config carried by the reply, after checking the reply header.
    fn send_get_shmem_config_request(
        mut endpoint: Endpoint<VhostUserMsgHeader<FrontendReq>>,
    ) -> VhostUserShMemConfig {
        let request = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0);
        endpoint
            .send_message(&request, &VhostUserEmpty, None)
            .unwrap();

        let (reply_hdr, reply_config, rfds) =
            endpoint.recv_body::<VhostUserShMemConfig>().unwrap();
        assert_eq!(reply_hdr.get_code().unwrap(), FrontendReq::GET_SHMEM_CONFIG);
        assert!(reply_hdr.is_reply());
        assert!(rfds.is_none());
        reply_config
    }

    /// Build a handler on `p1` that behaves as if the `SHMEM` protocol feature was acked.
    fn create_handler_with_shmem(
        backend: Arc<Mutex<DummyBackendReqHandler>>,
        p1: UnixStream,
    ) -> BackendReqHandler<Mutex<DummyBackendReqHandler>> {
        let mut handler = BackendReqHandler::new(FrontendEndpoint::from_stream(p1), backend);
        handler.acked_protocol_features = VhostUserProtocolFeatures::SHMEM.bits();
        handler
    }

    /// Install `config` in a dummy backend, serve one `GET_SHMEM_CONFIG` request issued
    /// from a helper thread, and return what the frontend side received.
    fn roundtrip_shmem_config(config: VhostUserShMemConfig) -> VhostUserShMemConfig {
        let (backend_side, frontend_side) = UnixStream::pair().unwrap();
        let mut dummy_backend = DummyBackendReqHandler::new();
        dummy_backend.set_shmem_config(config);
        let mut handler =
            create_handler_with_shmem(Arc::new(Mutex::new(dummy_backend)), backend_side);

        let frontend = std::thread::spawn(move || {
            send_get_shmem_config_request(FrontendEndpoint::from_stream(frontend_side))
        });
        handler.handle_request().unwrap();
        frontend.join().unwrap()
    }

    /// Eight contiguous regions are reported back unchanged; the remaining slots are zero.
    #[test]
    fn test_get_shmem_config_multiple_regions() {
        let memory_sizes = [
            0x1000, 0x2000, 0x3000, 0x4000, 0x5000, 0x6000, 0x7000, 0x8000,
        ];
        let reply_config = roundtrip_shmem_config(VhostUserShMemConfig::new(8, &memory_sizes));

        assert_eq!(reply_config.nregions, 8);
        (0..8).for_each(|i| {
            assert_eq!(reply_config.memory_sizes[i], (i as u64 + 1) * 0x1000);
        });
        (8..256).for_each(|i| {
            assert_eq!(reply_config.memory_sizes[i], 0);
        });
    }

    /// Regions separated by empty slots keep their positions in the reply.
    #[test]
    fn test_get_shmem_config_non_continuous_regions() {
        let memory_sizes = [0x10000, 0, 0x20000, 0, 0, 0, 0, 0];
        let reply_config = roundtrip_shmem_config(VhostUserShMemConfig::new(2, &memory_sizes));

        assert_eq!(reply_config.nregions, 2);
        assert_eq!(reply_config.memory_sizes[0], 0x10000);
        assert_eq!(reply_config.memory_sizes[1], 0);
        assert_eq!(reply_config.memory_sizes[2], 0x20000);
        (3..256).for_each(|i| {
            assert_eq!(reply_config.memory_sizes[i], 0);
        });
    }

    /// Without the `SHMEM` protocol feature acked, the request is rejected.
    #[test]
    fn test_get_shmem_config_feature_not_negotiated() {
        let (p1, p2) = UnixStream::pair().unwrap();
        let backend = Arc::new(Mutex::new(DummyBackendReqHandler::new()));
        let mut handler = BackendReqHandler::new(FrontendEndpoint::from_stream(p1), backend);

        let mut frontend_endpoint = FrontendEndpoint::from_stream(p2);
        std::thread::spawn(move || {
            let request = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0);
            // The send result is deliberately ignored; only the handler's verdict matters.
            let _ = frontend_endpoint.send_message(&request, &VhostUserEmpty, None);
        });

        assert!(handler.handle_request().is_err());
    }
}