use std::{
fs::File,
io,
io::Read,
num::Wrapping,
ops::Deref,
os::unix::{
net::{UnixListener, UnixStream},
prelude::{AsRawFd, FromRawFd, RawFd},
},
sync::{Arc, RwLock},
};
use futures::executor::{ThreadPool, ThreadPoolBuilder};
use log::warn;
use vhost_user_backend::{VringEpollHandler, VringRwLock, VringT};
use virtio_queue::QueueOwnedT;
use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
use vm_memory::{GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap};
use vmm_sys_util::{
epoll::EventSet,
eventfd::{EventFd, EFD_NONBLOCK},
};
use crate::{
rxops::*,
thread_backend::*,
vhu_vsock::{
CidMap, ConnMapKey, Error, Result, VhostUserVsockBackend, BACKEND_EVENT, SIBLING_VM_EVENT,
VSOCK_HOST_CID,
},
vsock_conn::*,
};
/// Shared handle to the vhost-user vsock backend, as held by the vring
/// epoll handler.
type ArcVhostBknd = Arc<VhostUserVsockBackend>;
/// Which RX source was serviced: standard host-to-guest traffic or raw
/// packets forwarded from sibling VMs.
enum RxQueueType {
Standard,
RawPkts,
}
/// Per-guest worker state for servicing one vsock device instance.
pub(crate) struct VhostUserVsockThread {
/// Guest memory map; `None` until the frontend shares memory.
pub mem: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
/// Whether the EVENT_IDX feature is in use (changes notification logic).
pub event_idx: bool,
// Raw fd of the host-side Unix listener socket (compared in handle_event).
host_sock: RawFd,
// Path of the host-side Unix socket; removed again in Drop.
host_sock_path: String,
// Listener accepting host application connections.
host_listener: UnixListener,
// Vring epoll worker; installed via set_vring_worker().
vring_worker: Option<Arc<VringEpollHandler<ArcVhostBknd, VringRwLock, ()>>>,
// Owns the backend epoll fd so it is closed when the thread is dropped.
epoll_file: File,
/// Connection/stream bookkeeping and pending RX queues.
pub thread_backend: VsockThreadBackend,
// CID of the guest served by this thread.
guest_cid: u64,
// Single-threaded pool used to return used descriptors asynchronously.
pool: ThreadPool,
// Next candidate local port for host-initiated connections.
local_port: Wrapping<u32>,
// Per-connection TX buffer size, in bytes.
tx_buffer_size: u32,
/// Event fd signalled by sibling VMs when they queue raw packets for us.
pub sibling_event_fd: EventFd,
// Which RX queue type ran last; used by process_rx to alternate fairly.
last_processed: RxQueueType,
}
impl VhostUserVsockThread {
/// Create a worker for the guest with `guest_cid`, backed by a Unix
/// listener bound at `uds_path`.
///
/// Publishes this guest's raw-packet queue and sibling event fd in
/// `cid_map`, and registers the listener on the backend epoll fd.
/// Errors if binding, epoll/eventfd creation, or thread-pool creation
/// fails.
pub fn new(
uds_path: String,
guest_cid: u64,
tx_buffer_size: u32,
cid_map: Arc<RwLock<CidMap>>,
) -> Result<Self> {
// Best-effort removal of a stale socket file from a previous run.
let _ = std::fs::remove_file(uds_path.clone());
let host_sock = UnixListener::bind(&uds_path)
.and_then(|sock| sock.set_nonblocking(true).map(|_| sock))
.map_err(Error::UnixBind)?;
let epoll_fd = epoll::create(true).map_err(Error::EpollFdCreate)?;
// SAFETY: epoll_fd is a freshly created, valid descriptor owned
// exclusively by this File from here on (closed on Drop).
let epoll_file = unsafe { File::from_raw_fd(epoll_fd) };
let host_raw_fd = host_sock.as_raw_fd();
let sibling_event_fd = EventFd::new(EFD_NONBLOCK).map_err(Error::EventFdCreate)?;
let thread_backend = VsockThreadBackend::new(
uds_path.clone(),
epoll_fd,
guest_cid,
tx_buffer_size,
cid_map.clone(),
);
// Make this guest reachable by CID: siblings push into the raw-pkts
// queue and signal the event fd.
cid_map.write().unwrap().insert(
guest_cid,
(
thread_backend.raw_pkts_queue.clone(),
sibling_event_fd.try_clone().unwrap(),
),
);
let thread = VhostUserVsockThread {
mem: None,
event_idx: false,
host_sock: host_sock.as_raw_fd(),
host_sock_path: uds_path,
host_listener: host_sock,
vring_worker: None,
epoll_file,
thread_backend,
guest_cid,
pool: ThreadPoolBuilder::new()
.pool_size(1)
.create()
.map_err(Error::CreateThreadPool)?,
local_port: Wrapping(0),
tx_buffer_size,
sibling_event_fd,
last_processed: RxQueueType::Standard,
};
// Start watching the host listener for incoming connections.
VhostUserVsockThread::epoll_register(epoll_fd, host_raw_fd, epoll::Events::EPOLLIN)?;
Ok(thread)
}
pub fn epoll_register(epoll_fd: RawFd, fd: RawFd, evset: epoll::Events) -> Result<()> {
epoll::ctl(
epoll_fd,
epoll::ControlOptions::EPOLL_CTL_ADD,
fd,
epoll::Event::new(evset, fd as u64),
)
.map_err(Error::EpollAdd)?;
Ok(())
}
pub fn epoll_unregister(epoll_fd: RawFd, fd: RawFd) -> Result<()> {
epoll::ctl(
epoll_fd,
epoll::ControlOptions::EPOLL_CTL_DEL,
fd,
epoll::Event::new(epoll::Events::empty(), 0),
)
.map_err(Error::EpollRemove)?;
Ok(())
}
pub fn epoll_modify(epoll_fd: RawFd, fd: RawFd, evset: epoll::Events) -> Result<()> {
epoll::ctl(
epoll_fd,
epoll::ControlOptions::EPOLL_CTL_MOD,
fd,
epoll::Event::new(evset, fd as u64),
)
.map_err(Error::EpollModify)?;
Ok(())
}
/// Raw fd of the backend epoll set (owned by `epoll_file`).
fn get_epoll_fd(&self) -> RawFd {
self.epoll_file.as_raw_fd()
}
/// Install the vring epoll worker and register both the backend epoll fd
/// and the sibling event fd with it.
///
/// Panics if `vring_worker` is `None` or if listener registration fails.
pub fn set_vring_worker(
    &mut self,
    vring_worker: Option<Arc<VringEpollHandler<ArcVhostBknd, VringRwLock, ()>>>,
) {
    self.vring_worker = vring_worker;
    let backend_fd = self.get_epoll_fd();
    let sibling_fd = self.sibling_event_fd.as_raw_fd();
    let worker = self.vring_worker.as_ref().unwrap();
    worker
        .register_listener(backend_fd, EventSet::IN, u64::from(BACKEND_EVENT))
        .unwrap();
    worker
        .register_listener(sibling_fd, EventSet::IN, u64::from(SIBLING_VM_EVENT))
        .unwrap();
}
/// Drain and dispatch every event currently pending on the backend
/// epoll fd (non-blocking: timeout 0).
pub fn process_backend_evt(&mut self, _evset: EventSet) {
    let mut events = vec![epoll::Event::new(epoll::Events::empty(), 0); 32];
    loop {
        match epoll::wait(self.epoll_file.as_raw_fd(), 0, events.as_mut_slice()) {
            Ok(count) => {
                for event in &events[..count] {
                    let fd = event.data as RawFd;
                    let evset = epoll::Events::from_bits(event.events).unwrap();
                    self.handle_event(fd, evset);
                }
            }
            // Retry the wait if it was interrupted by a signal.
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(_) => warn!("failed to consume new epoll event"),
        }
        break;
    }
}
fn handle_event(&mut self, fd: RawFd, evset: epoll::Events) {
if fd == self.host_sock {
let conn = self.host_listener.accept().map_err(Error::UnixAccept);
if self.mem.is_some() {
conn.and_then(|(stream, _)| {
stream
.set_nonblocking(true)
.map(|_| stream)
.map_err(Error::UnixAccept)
})
.and_then(|stream| self.add_stream_listener(stream))
.unwrap_or_else(|err| {
warn!("Unable to accept new local connection: {:?}", err);
});
} else {
conn.map(drop).unwrap_or_else(|err| {
warn!("Error closing an incoming connection: {:?}", err);
});
}
} else {
if let std::collections::hash_map::Entry::Vacant(_) =
self.thread_backend.listener_map.entry(fd)
{
if evset.bits() != epoll::Events::EPOLLIN.bits() {
return;
}
let mut unix_stream = match self.thread_backend.stream_map.remove(&fd) {
Some(uds) => uds,
None => {
warn!("Error while searching fd in the stream map");
return;
}
};
let peer_port = match Self::read_local_stream_port(&mut unix_stream) {
Ok(port) => port,
Err(err) => {
warn!("Error while parsing \"connect PORT\n\" command: {:?}", err);
return;
}
};
let local_port = match self.allocate_local_port() {
Ok(lp) => lp,
Err(err) => {
warn!("Error while allocating local port: {:?}", err);
return;
}
};
self.thread_backend
.listener_map
.insert(fd, ConnMapKey::new(local_port, peer_port));
let conn_map_key = ConnMapKey::new(local_port, peer_port);
let mut new_conn = VsockConnection::new_local_init(
unix_stream,
VSOCK_HOST_CID,
local_port,
self.guest_cid,
peer_port,
self.get_epoll_fd(),
self.tx_buffer_size,
);
new_conn.rx_queue.enqueue(RxOps::Request);
new_conn.set_peer_port(peer_port);
self.thread_backend.conn_map.insert(conn_map_key, new_conn);
self.thread_backend
.backend_rxq
.push_back(ConnMapKey::new(local_port, peer_port));
Self::epoll_modify(
self.get_epoll_fd(),
fd,
epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
)
.unwrap();
} else {
let key = self.thread_backend.listener_map.get(&fd).unwrap();
let conn = self.thread_backend.conn_map.get_mut(key).unwrap();
if evset.bits() == epoll::Events::EPOLLOUT.bits() {
match conn.tx_buf.flush_to(&mut conn.stream) {
Ok(cnt) => {
if cnt > 0 {
conn.fwd_cnt += Wrapping(cnt as u32);
conn.rx_queue.enqueue(RxOps::CreditUpdate);
}
self.thread_backend
.backend_rxq
.push_back(ConnMapKey::new(conn.local_port, conn.peer_port));
}
Err(e) => {
dbg!("Error: {:?}", e);
}
}
return;
}
Self::epoll_unregister(self.epoll_file.as_raw_fd(), fd).unwrap();
conn.rx_queue.enqueue(RxOps::Rw);
self.thread_backend
.backend_rxq
.push_back(ConnMapKey::new(conn.local_port, conn.peer_port));
}
}
}
/// Allocate a host-side local port for a new host-initiated connection.
///
/// Scans forward (with wrap-around at `u32::MAX`) from the last
/// allocation point and returns the first port not already present in
/// `local_port_set`. Returns `Error::NoFreeLocalPort` only after every
/// port value has been tried.
///
/// Bug fix: the previous version compared the candidate against the
/// starting point BEFORE incrementing, so it reported exhaustion on the
/// very first taken port instead of scanning the remaining space; it
/// could also overflow-panic on `+ 1` at `u32::MAX` in debug builds.
fn allocate_local_port(&mut self) -> Result<u32> {
    let start = self.local_port.0;
    let mut candidate = start;
    loop {
        if !self.thread_backend.local_port_set.contains(&candidate) {
            // Record the port as taken and remember where to resume the
            // next search.
            self.local_port = Wrapping(candidate) + Wrapping(1);
            self.thread_backend.local_port_set.insert(candidate);
            return Ok(candidate);
        }
        // Candidate is in use: advance with wrap-around, and give up
        // once we come back to the starting point (all ports tried).
        candidate = candidate.wrapping_add(1);
        if candidate == start {
            return Err(Error::NoFreeLocalPort);
        }
    }
}
fn read_local_stream_port(stream: &mut UnixStream) -> Result<u32> {
let mut buf = [0u8; 32];
const MIN_READ_LEN: usize = 10;
stream
.read_exact(&mut buf[..MIN_READ_LEN])
.map_err(Error::UnixRead)?;
let mut read_len = MIN_READ_LEN;
while buf[read_len - 1] != b'\n' && read_len < buf.len() {
stream
.read_exact(&mut buf[read_len..read_len + 1])
.map_err(Error::UnixRead)?;
read_len += 1;
}
let mut word_iter = std::str::from_utf8(&buf[..read_len])
.map_err(Error::ConvertFromUtf8)?
.split_whitespace();
word_iter
.next()
.ok_or(Error::InvalidPortRequest)
.and_then(|word| {
if word.to_lowercase() == "connect" {
Ok(())
} else {
Err(Error::InvalidPortRequest)
}
})
.and_then(|_| word_iter.next().ok_or(Error::InvalidPortRequest))
.and_then(|word| word.parse::<u32>().map_err(Error::ParseInteger))
.map_err(|e| Error::ReadStreamPort(Box::new(e)))
}
fn add_stream_listener(&mut self, stream: UnixStream) -> Result<()> {
let stream_fd = stream.as_raw_fd();
self.thread_backend.stream_map.insert(stream_fd, stream);
VhostUserVsockThread::epoll_register(
self.get_epoll_fd(),
stream_fd,
epoll::Events::EPOLLIN,
)?;
Ok(())
}
/// Fill available RX descriptors from the queue chosen by
/// `rx_queue_type` and return them to the guest.
///
/// Returns `Ok(true)` if at least one descriptor chain was consumed;
/// errors if guest memory is not configured or the queue iterator fails.
fn process_rx_queue(
&mut self,
vring: &VringRwLock,
rx_queue_type: RxQueueType,
) -> Result<bool> {
let mut used_any = false;
// RX requires access to guest memory.
let atomic_mem = match &self.mem {
Some(m) => m,
None => return Err(Error::NoMemoryConfigured),
};
let mut vring_mut = vring.get_mut();
let queue = vring_mut.get_queue_mut();
while let Some(mut avail_desc) = queue
.iter(atomic_mem.memory())
.map_err(|_| Error::IterateQueue)?
.next()
{
used_any = true;
let mem = atomic_mem.clone().memory();
let head_idx = avail_desc.head_index();
let used_len = match VsockPacket::from_rx_virtq_chain(
mem.deref(),
&mut avail_desc,
self.tx_buffer_size,
) {
Ok(mut pkt) => {
// Fill the packet from the backend queue matching the type.
let recv_result = match rx_queue_type {
RxQueueType::Standard => self.thread_backend.recv_pkt(&mut pkt),
RxQueueType::RawPkts => self.thread_backend.recv_raw_pkt(&mut pkt),
};
if recv_result.is_ok() {
PKT_HEADER_SIZE + pkt.len() as usize
} else {
// Nothing could be received: rewind so this descriptor is
// offered again on the next pass, and stop.
queue.iter(mem).unwrap().go_to_previous_position();
break;
}
}
Err(e) => {
warn!("vsock: RX queue error: {:?}", e);
0
}
};
let vring = vring.clone();
let event_idx = self.event_idx;
// Return the used descriptor (and notify the guest if required)
// asynchronously on the single-threaded pool.
self.pool.spawn_ok(async move {
if event_idx {
if vring.add_used(head_idx, used_len as u32).is_err() {
warn!("Could not return used descriptors to ring");
}
match vring.needs_notification() {
Err(_) => {
warn!("Could not check if queue needs to be notified");
vring.signal_used_queue().unwrap();
}
Ok(needs_notification) => {
if needs_notification {
vring.signal_used_queue().unwrap();
}
}
}
} else {
// Without EVENT_IDX, always signal after adding the used entry.
if vring.add_used(head_idx, used_len as u32).is_err() {
warn!("Could not return used descriptors to ring");
}
vring.signal_used_queue().unwrap();
}
});
// Stop once there is no more pending work for this queue type.
match rx_queue_type {
RxQueueType::Standard => {
if !self.thread_backend.pending_rx() {
break;
}
}
RxQueueType::RawPkts => {
if !self.thread_backend.pending_raw_pkts() {
break;
}
}
}
}
Ok(used_any)
}
/// Drain RX traffic originating from host Unix sockets into `vring`.
///
/// With EVENT_IDX negotiated, notifications are suppressed while
/// draining and the loop repeats until re-enabling them reports no new
/// work; otherwise a single pass is made.
fn process_unix_sockets(&mut self, vring: &VringRwLock, event_idx: bool) -> Result<bool> {
    if !event_idx {
        self.process_rx_queue(vring, RxQueueType::Standard)?;
        return Ok(false);
    }
    while self.thread_backend.pending_rx() {
        vring.disable_notification().unwrap();
        self.process_rx_queue(vring, RxQueueType::Standard)?;
        if !vring.enable_notification().unwrap() {
            break;
        }
    }
    Ok(false)
}
/// Drain raw packets queued by sibling VMs into `vring`.
///
/// Mirrors `process_unix_sockets`, but for the raw-packet queue.
pub fn process_raw_pkts(&mut self, vring: &VringRwLock, event_idx: bool) -> Result<bool> {
    if event_idx {
        while self.thread_backend.pending_raw_pkts() {
            // Suppress guest notifications while we drain, then re-check.
            vring.disable_notification().unwrap();
            self.process_rx_queue(vring, RxQueueType::RawPkts)?;
            if !vring.enable_notification().unwrap() {
                break;
            }
        }
    } else {
        self.process_rx_queue(vring, RxQueueType::RawPkts)?;
    }
    Ok(false)
}
/// Service the RX side, alternating between standard traffic and
/// sibling raw packets so neither source can starve the other.
pub fn process_rx(&mut self, vring: &VringRwLock, event_idx: bool) -> Result<bool> {
match self.last_processed {
// Standard traffic ran last time: give raw packets the first turn.
RxQueueType::Standard => {
if self.thread_backend.pending_raw_pkts() {
self.process_raw_pkts(vring, event_idx)?;
self.last_processed = RxQueueType::RawPkts;
}
if self.thread_backend.pending_rx() {
self.process_unix_sockets(vring, event_idx)?;
}
}
// Raw packets ran last time: standard traffic goes first now.
RxQueueType::RawPkts => {
if self.thread_backend.pending_rx() {
self.process_unix_sockets(vring, event_idx)?;
self.last_processed = RxQueueType::Standard;
}
if self.thread_backend.pending_raw_pkts() {
self.process_raw_pkts(vring, event_idx)?;
}
}
}
Ok(false)
}
fn process_tx_queue(&mut self, vring: &VringRwLock) -> Result<bool> {
let mut used_any = false;
let atomic_mem = match &self.mem {
Some(m) => m,
None => return Err(Error::NoMemoryConfigured),
};
while let Some(mut avail_desc) = vring
.get_mut()
.get_queue_mut()
.iter(atomic_mem.memory())
.map_err(|_| Error::IterateQueue)?
.next()
{
used_any = true;
let mem = atomic_mem.clone().memory();
let head_idx = avail_desc.head_index();
let pkt = match VsockPacket::from_tx_virtq_chain(
mem.deref(),
&mut avail_desc,
self.tx_buffer_size,
) {
Ok(pkt) => pkt,
Err(e) => {
dbg!("vsock: error reading TX packet: {:?}", e);
continue;
}
};
if self.thread_backend.send_pkt(&pkt).is_err() {
vring
.get_mut()
.get_queue_mut()
.iter(mem)
.unwrap()
.go_to_previous_position();
break;
}
let used_len = 0;
let vring = vring.clone();
let event_idx = self.event_idx;
self.pool.spawn_ok(async move {
if event_idx {
if vring.add_used(head_idx, used_len as u32).is_err() {
warn!("Could not return used descriptors to ring");
}
match vring.needs_notification() {
Err(_) => {
warn!("Could not check if queue needs to be notified");
vring.signal_used_queue().unwrap();
}
Ok(needs_notification) => {
if needs_notification {
vring.signal_used_queue().unwrap();
}
}
}
} else {
if vring.add_used(head_idx, used_len as u32).is_err() {
warn!("Could not return used descriptors to ring");
}
vring.signal_used_queue().unwrap();
}
});
}
Ok(used_any)
}
/// TX entry point: with EVENT_IDX, drain repeatedly with notifications
/// suppressed until re-enabling them reports no new work; otherwise
/// make a single pass.
pub fn process_tx(&mut self, vring_lock: &VringRwLock, event_idx: bool) -> Result<bool> {
    if !event_idx {
        self.process_tx_queue(vring_lock)?;
        return Ok(false);
    }
    loop {
        vring_lock.disable_notification().unwrap();
        self.process_tx_queue(vring_lock)?;
        if !vring_lock.enable_notification().unwrap() {
            return Ok(false);
        }
    }
}
}
impl Drop for VhostUserVsockThread {
    fn drop(&mut self) {
        // Best-effort cleanup of the host-side Unix socket file.
        let _ = std::fs::remove_file(&self.host_sock_path);
        // Deregister this guest's CID so siblings can no longer route
        // raw packets to it.
        let mut cid_map = self.thread_backend.cid_map.write().unwrap();
        cid_map.remove(&self.guest_cid);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use tempfile::tempdir;
use vm_memory::GuestAddress;
use vmm_sys_util::eventfd::EventFd;
// TX buffer size used for every test connection.
const CONN_TX_BUF_SIZE: u32 = 64 * 1024;
impl VhostUserVsockThread {
// Test-only accessor for the private backend epoll file.
fn get_epoll_file(&self) -> &File {
&self.epoll_file
}
}
// Happy-path smoke test: construct a thread, exercise the epoll helper
// functions, and run the TX/RX/backed-event paths on a minimal vring.
#[test]
fn test_vsock_thread() {
let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));
let test_dir = tempdir().expect("Could not create a temp test directory.");
let t = VhostUserVsockThread::new(
test_dir
.path()
.join("test_vsock_thread.vsock")
.display()
.to_string(),
3,
CONN_TX_BUF_SIZE,
cid_map,
);
assert!(t.is_ok());
let mut t = t.unwrap();
let epoll_fd = t.get_epoll_file().as_raw_fd();
let mem = GuestMemoryAtomic::new(
GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(),
);
t.mem = Some(mem.clone());
// Register, modify, unregister and re-register a dummy fd against the
// thread's epoll set.
let dummy_fd = EventFd::new(0).unwrap();
assert!(VhostUserVsockThread::epoll_register(
epoll_fd,
dummy_fd.as_raw_fd(),
epoll::Events::EPOLLOUT
)
.is_ok());
assert!(VhostUserVsockThread::epoll_modify(
epoll_fd,
dummy_fd.as_raw_fd(),
epoll::Events::EPOLLIN
)
.is_ok());
assert!(VhostUserVsockThread::epoll_unregister(epoll_fd, dummy_fd.as_raw_fd()).is_ok());
assert!(VhostUserVsockThread::epoll_register(
epoll_fd,
dummy_fd.as_raw_fd(),
epoll::Events::EPOLLIN
)
.is_ok());
let vring = VringRwLock::new(mem, 0x1000).unwrap();
vring.set_queue_info(0x100, 0x200, 0x300).unwrap();
vring.set_queue_ready(true);
// TX/RX processing should succeed with guest memory configured.
assert!(t.process_tx(&vring, false).is_ok());
assert!(t.process_tx(&vring, true).is_ok());
t.thread_backend
.backend_rxq
.push_back(ConnMapKey::new(0, 0));
assert!(t.process_rx(&vring, false).is_ok());
assert!(t.process_rx(&vring, true).is_ok());
// Trigger the dummy fd so process_backend_evt has an event to consume.
dummy_fd.write(1).unwrap();
t.process_backend_evt(EventSet::empty());
test_dir.close().unwrap();
}
// Failure paths: unbindable socket path, invalid epoll fds, and
// processing without guest memory configured.
#[test]
fn test_vsock_thread_failures() {
let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));
let test_dir = tempdir().expect("Could not create a temp test directory.");
// Binding under /sys should fail.
let t = VhostUserVsockThread::new(
"/sys/not_allowed.vsock".to_string(),
3,
CONN_TX_BUF_SIZE,
cid_map.clone(),
);
assert!(t.is_err());
let vsock_socket_path = test_dir
.path()
.join("test_vsock_thread_failures.vsock")
.display()
.to_string();
let mut t =
VhostUserVsockThread::new(vsock_socket_path, 3, CONN_TX_BUF_SIZE, cid_map).unwrap();
assert!(VhostUserVsockThread::epoll_register(-1, -1, epoll::Events::EPOLLIN).is_err());
assert!(VhostUserVsockThread::epoll_modify(-1, -1, epoll::Events::EPOLLIN).is_err());
assert!(VhostUserVsockThread::epoll_unregister(-1, -1).is_err());
let mem = GuestMemoryAtomic::new(
GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(),
);
let vring = VringRwLock::new(mem, 0x1000).unwrap();
// t.mem was never set, so queue processing must fail.
assert!(t.process_tx(&vring, false).is_err());
assert!(t.process_tx(&vring, true).is_err());
t.thread_backend
.backend_rxq
.push_back(ConnMapKey::new(0, 0));
assert!(t.process_rx(&vring, false).is_err());
assert!(t.process_rx(&vring, true).is_err());
test_dir.close().unwrap();
}
}