use std::{
collections::{HashMap, HashSet, VecDeque},
os::unix::{
net::UnixStream,
prelude::{AsRawFd, RawFd},
},
sync::{Arc, RwLock},
};
use log::{info, warn};
use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
use vm_memory::bitmap::BitmapSlice;
use crate::{
rxops::*,
vhu_vsock::{
CidMap, ConnMapKey, Error, Result, VSOCK_HOST_CID, VSOCK_OP_REQUEST, VSOCK_OP_RST,
VSOCK_TYPE_STREAM,
},
vhu_vsock_thread::VhostUserVsockThread,
vsock_conn::*,
};
/// Queue of raw packets, consumed front-to-back.
pub(crate) type RawPktsQ = VecDeque<RawVsockPacket>;
/// Owned copy of a vsock packet: the header bytes plus the payload bytes.
pub(crate) struct RawVsockPacket {
/// Header bytes of the packet (`PKT_HEADER_SIZE` bytes).
pub header: [u8; PKT_HEADER_SIZE],
/// Payload bytes of the packet; empty when the packet carries no data.
pub data: Vec<u8>,
}
impl RawVsockPacket {
fn from_vsock_packet<B: BitmapSlice>(pkt: &VsockPacket<B>) -> Result<Self> {
let mut raw_pkt = Self {
header: [0; PKT_HEADER_SIZE],
data: vec![0; pkt.len() as usize],
};
pkt.header_slice().copy_to(&mut raw_pkt.header);
if !pkt.is_empty() {
pkt.data_slice()
.ok_or(Error::PktBufMissing)?
.copy_to(raw_pkt.data.as_mut());
}
Ok(raw_pkt)
}
}
/// Per-thread bookkeeping for vsock connections between one guest and the host,
/// plus the queue of raw packets addressed to this guest by other guests.
pub(crate) struct VsockThreadBackend {
/// Maps a raw file descriptor to the key of the connection it belongs to.
pub listener_map: HashMap<RawFd, ConnMapKey>,
/// Connections, keyed by `ConnMapKey` (built from a packet's destination and source ports).
pub conn_map: HashMap<ConnMapKey, VsockConnection<UnixStream>>,
/// Keys of connections that have something to deliver to the guest; drained by `recv_pkt`.
pub backend_rxq: VecDeque<ConnMapKey>,
/// Maps a raw file descriptor to a host-side unix stream.
pub stream_map: HashMap<i32, UnixStream>,
/// Path prefix of the host sockets; a guest connection to port `P` connects to `<path>_<P>`.
host_socket_path: String,
/// Epoll file descriptor on which connection streams are registered.
epoll_fd: i32,
/// CID of the guest served by this backend; packets with another source CID are dropped.
guest_cid: u64,
/// Local ports currently in use by connections.
pub local_port_set: HashSet<u32>,
/// Tx buffer size handed to every connection created for the guest.
tx_buffer_size: u32,
/// Shared map looked up by destination CID to find a sibling's raw packet queue and event fd.
pub cid_map: Arc<RwLock<CidMap>>,
/// Raw packets queued for this backend; consumed by `recv_raw_pkt`.
pub raw_pkts_queue: Arc<RwLock<RawPktsQ>>,
}
impl VsockThreadBackend {
/// Creates a backend with no connections and an empty raw packet queue.
///
/// * `host_socket_path` - path prefix used to reach host-side listeners.
/// * `epoll_fd` - epoll file descriptor on which streams get registered.
/// * `guest_cid` - CID of the guest served by this backend.
/// * `tx_buffer_size` - tx buffer size given to each new connection.
/// * `cid_map` - shared map used to reach sibling backends by CID.
pub fn new(
    host_socket_path: String,
    epoll_fd: i32,
    guest_cid: u64,
    tx_buffer_size: u32,
    cid_map: Arc<RwLock<CidMap>>,
) -> Self {
    // Every backend starts with its own, initially empty, raw packet queue.
    let raw_pkts_queue = Arc::new(RwLock::new(VecDeque::new()));

    Self {
        host_socket_path,
        epoll_fd,
        guest_cid,
        tx_buffer_size,
        cid_map,
        raw_pkts_queue,
        listener_map: HashMap::default(),
        conn_map: HashMap::default(),
        stream_map: HashMap::default(),
        backend_rxq: VecDeque::default(),
        local_port_set: HashSet::default(),
    }
}
/// Returns `true` when at least one connection key is waiting in `backend_rxq`.
pub fn pending_rx(&self) -> bool {
    self.backend_rxq.front().is_some()
}
/// Returns `true` when at least one raw packet is waiting in `raw_pkts_queue`.
///
/// # Panics
///
/// Panics if the queue's lock is poisoned.
pub fn pending_raw_pkts(&self) -> bool {
    let queue = self.raw_pkts_queue.read().unwrap();
    !queue.is_empty()
}
pub fn recv_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()> {
let key = self.backend_rxq.pop_front().ok_or(Error::EmptyBackendRxQ)?;
let conn = match self.conn_map.get_mut(&key) {
Some(conn) => conn,
None => {
return Ok(());
}
};
if conn.rx_queue.peek() == Some(RxOps::Reset) {
let conn = self.conn_map.remove(&key).unwrap();
self.listener_map.remove(&conn.stream.as_raw_fd());
self.stream_map.remove(&conn.stream.as_raw_fd());
self.local_port_set.remove(&conn.local_port);
VhostUserVsockThread::epoll_unregister(conn.epoll_fd, conn.stream.as_raw_fd())
.unwrap_or_else(|err| {
warn!(
"Could not remove epoll listener for fd {:?}: {:?}",
conn.stream.as_raw_fd(),
err
)
});
pkt.set_op(VSOCK_OP_RST)
.set_src_cid(VSOCK_HOST_CID)
.set_dst_cid(conn.guest_cid)
.set_src_port(conn.local_port)
.set_dst_port(conn.peer_port)
.set_len(0)
.set_type(VSOCK_TYPE_STREAM)
.set_flags(0)
.set_buf_alloc(0)
.set_fwd_cnt(0);
return Ok(());
}
conn.recv_pkt(pkt)?;
Ok(())
}
/// Handles a packet sent by the guest.
///
/// Processing order:
/// 1. Packets whose source CID is not the configured guest CID are dropped.
/// 2. Packets addressed to a CID other than the host are copied into the raw
///    packet queue of the matching entry in `cid_map` and that entry's event fd
///    is signalled; packets for an unknown CID are dropped.
/// 3. Packets that are not of stream type are dropped.
/// 4. For a connection that does not exist yet, a `VSOCK_OP_REQUEST` opens a new
///    guest-initiated connection; any other operation is dropped.
/// 5. A `VSOCK_OP_RST` tears the connection down, unless a reset is already
///    queued on it.
/// 6. Everything else is forwarded to the connection; the connection is queued
///    on `backend_rxq` if it has rx operations pending afterwards.
///
/// Dropped packets still yield `Ok(())`.
///
/// # Errors
///
/// Propagates errors from copying the packet for a sibling and from the
/// connection's own packet handling.
///
/// # Panics
///
/// Panics if the `cid_map` lock or a sibling queue lock is poisoned.
pub fn send_pkt<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) -> Result<()> {
    if pkt.src_cid() != self.guest_cid {
        warn!(
            "vsock: dropping packet with inconsistent src_cid: {:?} from guest configured with CID: {:?}",
            pkt.src_cid(), self.guest_cid
        );
        return Ok(());
    }

    let dst_cid = pkt.dst_cid();
    if dst_cid != VSOCK_HOST_CID {
        let cid_map = self.cid_map.read().unwrap();
        match cid_map.get(&dst_cid) {
            Some((sibling_raw_pkts_queue, sibling_event_fd)) => {
                // Copy the packet before taking the sibling's queue lock so the
                // lock is held only for the push itself.
                let raw_pkt = RawVsockPacket::from_vsock_packet(pkt)?;
                sibling_raw_pkts_queue.write().unwrap().push_back(raw_pkt);

                // The packet is already queued, so a failed notification is not
                // fatal; log it so a stalled sibling can be diagnosed.
                if let Err(err) = sibling_event_fd.write(1) {
                    warn!(
                        "vsock: could not notify sibling with cid {:?}: {:?}",
                        dst_cid, err
                    );
                }
            }
            None => warn!("vsock: dropping packet for unknown cid: {:?}", dst_cid),
        }
        return Ok(());
    }

    // TODO: reply with RST when the packet has an unsupported type.
    if pkt.type_() != VSOCK_TYPE_STREAM {
        info!("vsock: dropping packet of unknown type");
        return Ok(());
    }

    let key = ConnMapKey::new(pkt.dst_port(), pkt.src_port());

    if !self.conn_map.contains_key(&key) {
        // Only a connection request may create a connection; any other packet
        // for an unknown connection is dropped.
        // TODO: reply with RST instead of dropping silently.
        if pkt.op() == VSOCK_OP_REQUEST {
            self.handle_new_guest_conn(pkt);
        }
        return Ok(());
    }

    if pkt.op() == VSOCK_OP_RST {
        // The connection is known to exist: its key was found just above.
        let conn = self.conn_map.get(&key).unwrap();
        if conn.rx_queue.contains(RxOps::Reset.bitmask()) {
            // A reset is already queued; `recv_pkt` performs the teardown.
            return Ok(());
        }

        let conn = self.conn_map.remove(&key).unwrap();
        let stream_fd = conn.stream.as_raw_fd();
        self.listener_map.remove(&stream_fd);
        self.stream_map.remove(&stream_fd);
        self.local_port_set.remove(&conn.local_port);
        if let Err(err) = VhostUserVsockThread::epoll_unregister(conn.epoll_fd, stream_fd) {
            warn!(
                "Could not remove epoll listener for fd {:?}: {:?}",
                stream_fd, err
            );
        }
        return Ok(());
    }

    // Forward the packet to its connection.
    let conn = self.conn_map.get_mut(&key).unwrap();
    conn.send_pkt(pkt)?;

    if conn.rx_queue.pending_rx() {
        // The connection queued new rx operations while handling the packet.
        self.backend_rxq.push_back(key);
    }

    Ok(())
}
pub fn recv_raw_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()> {
let raw_vsock_pkt = self
.raw_pkts_queue
.write()
.unwrap()
.pop_front()
.ok_or(Error::EmptyRawPktsQueue)?;
pkt.set_header_from_raw(&raw_vsock_pkt.header).unwrap();
if !raw_vsock_pkt.data.is_empty() {
let buf = pkt.data_slice().ok_or(Error::PktBufMissing)?;
buf.copy_from(&raw_vsock_pkt.data);
}
Ok(())
}
fn handle_new_guest_conn<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) {
let port_path = format!("{}_{}", self.host_socket_path, pkt.dst_port());
UnixStream::connect(port_path)
.and_then(|stream| stream.set_nonblocking(true).map(|_| stream))
.map_err(Error::UnixConnect)
.and_then(|stream| self.add_new_guest_conn(stream, pkt))
.unwrap_or_else(|_| self.enq_rst());
}
/// Records a new guest-initiated connection backed by `stream`.
///
/// The connection owns a clone of `stream`. Every fd-keyed structure
/// (`listener_map`, `stream_map` and the epoll registration) uses the file
/// descriptor of the connection's own stream, because connection teardown looks
/// entries up through `conn.stream.as_raw_fd()`; keying by the fd of the original
/// `stream` would leave those entries behind and unregister the wrong fd.
///
/// All fallible steps (cloning the stream, registering with epoll) run before
/// any map is touched, so a failure leaves the backend unchanged.
///
/// # Errors
///
/// Returns `Error::UnixConnect` when the stream cannot be cloned and propagates
/// the error of a failed epoll registration.
fn add_new_guest_conn<B: BitmapSlice>(
    &mut self,
    stream: UnixStream,
    pkt: &VsockPacket<B>,
) -> Result<()> {
    let conn = VsockConnection::new_peer_init(
        stream.try_clone().map_err(Error::UnixConnect)?,
        pkt.dst_cid(),
        pkt.dst_port(),
        pkt.src_cid(),
        pkt.src_port(),
        self.epoll_fd,
        pkt.buf_alloc(),
        self.tx_buffer_size,
    );

    // Use the fd of the stream held by the connection, not the fd of `stream`:
    // the clone has its own descriptor number.
    let stream_fd = conn.stream.as_raw_fd();

    // Register first: if this fails nothing has been recorded yet.
    VhostUserVsockThread::epoll_register(
        self.epoll_fd,
        stream_fd,
        epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
    )?;

    self.listener_map
        .insert(stream_fd, ConnMapKey::new(pkt.dst_port(), pkt.src_port()));
    self.conn_map
        .insert(ConnMapKey::new(pkt.dst_port(), pkt.src_port()), conn);
    self.backend_rxq
        .push_back(ConnMapKey::new(pkt.dst_port(), pkt.src_port()));
    self.stream_map.insert(stream_fd, stream);
    self.local_port_set.insert(pkt.dst_port());

    Ok(())
}
fn enq_rst(&mut self) {
dbg!("New guest conn error: Enqueue RST");
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::vhu_vsock::{VhostUserVsockBackend, VsockConfig, VSOCK_OP_RW};
use std::os::unix::net::UnixListener;
use tempfile::tempdir;
use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
const DATA_LEN: usize = 16;
const CONN_TX_BUF_SIZE: u32 = 64 * 1024;
// Walks a guest packet through `send_pkt`/`recv_pkt`: dropped packets first, then
// a connection request, data, and a reset.
#[test]
fn test_vsock_thread_backend() {
const CID: u64 = 3;
const VSOCK_PEER_PORT: u32 = 1234;
let test_dir = tempdir().expect("Could not create a temp test directory.");
let vsock_socket_path = test_dir.path().join("test_vsock_thread_backend.vsock");
// The backend connects to `<socket path>_<port>`, so listen on that path for port 1234.
let vsock_peer_path = test_dir.path().join("test_vsock_thread_backend.vsock_1234");
let _listener = UnixListener::bind(&vsock_peer_path).unwrap();
let epoll_fd = epoll::create(false).unwrap();
let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));
let mut vtp = VsockThreadBackend::new(
vsock_socket_path.display().to_string(),
epoll_fd,
CID,
CONN_TX_BUF_SIZE,
cid_map,
);
// A fresh backend has nothing queued for the guest.
assert!(!vtp.pending_rx());
let mut pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN];
let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE);
// SAFETY: `hdr_raw` and `data_raw` are disjoint parts of `pkt_raw`, which lives until the
// end of the test. NOTE(review): assumed to be all that `VsockPacket::new` requires.
let mut packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() };
// Receiving with an empty backend queue is an error.
assert_eq!(
vtp.recv_pkt(&mut packet).unwrap_err().to_string(),
Error::EmptyBackendRxQ.to_string()
);
// All-zero header: the source CID does not match `CID`, so the packet is dropped (still `Ok`).
assert!(vtp.send_pkt(&packet).is_ok());
// Setting the type alone does not help: the source CID check comes first.
packet.set_type(VSOCK_TYPE_STREAM);
assert!(vtp.send_pkt(&packet).is_ok());
// Valid addressing but no connection yet; the op has not been set to a request,
// so the packet is presumably dropped.
packet.set_src_cid(CID);
packet.set_dst_cid(VSOCK_HOST_CID);
packet.set_dst_port(VSOCK_PEER_PORT);
assert!(vtp.send_pkt(&packet).is_ok());
// A connection request makes the backend connect to the listener bound above.
packet.set_op(VSOCK_OP_REQUEST);
assert!(vtp.send_pkt(&packet).is_ok());
// With the connection in place, a data packet is forwarded to it.
packet.set_op(VSOCK_OP_RW);
assert!(vtp.send_pkt(&packet).is_ok());
// A reset from the guest is accepted as well.
packet.set_op(VSOCK_OP_RST);
assert!(vtp.send_pkt(&packet).is_ok());
// The key queued when the connection was created can still be received without error.
assert!(vtp.recv_pkt(&mut packet).is_ok());
// Clean up; removal errors are ignored on purpose.
let _ = std::fs::remove_file(&vsock_peer_path);
let _ = std::fs::remove_file(&vsock_socket_path);
test_dir.close().unwrap();
}
// Sends a packet addressed to a sibling CID and checks that it shows up, byte for
// byte, in the sibling backend's raw packet queue.
#[test]
fn test_vsock_thread_backend_sibling_vms() {
const CID: u64 = 3;
const SIBLING_CID: u64 = 4;
const SIBLING_LISTENING_PORT: u32 = 1234;
let test_dir = tempdir().expect("Could not create a temp test directory.");
let vsock_socket_path = test_dir
.path()
.join("test_vsock_thread_backend.vsock")
.display()
.to_string();
let sibling_vhost_socket_path = test_dir
.path()
.join("test_vsock_thread_backend_sibling.socket")
.display()
.to_string();
let sibling_vsock_socket_path = test_dir
.path()
.join("test_vsock_thread_backend_sibling.vsock")
.display()
.to_string();
// Both backends share this map so that one can reach the other by CID.
let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));
let sibling_config = VsockConfig::new(
SIBLING_CID,
sibling_vhost_socket_path,
sibling_vsock_socket_path,
CONN_TX_BUF_SIZE,
);
// NOTE(review): creating the sibling backend presumably registers `SIBLING_CID` in
// `cid_map`; the assertions below rely on it.
let sibling_backend =
Arc::new(VhostUserVsockBackend::new(sibling_config, cid_map.clone()).unwrap());
let epoll_fd = epoll::create(false).unwrap();
let mut vtp =
VsockThreadBackend::new(vsock_socket_path, epoll_fd, CID, CONN_TX_BUF_SIZE, cid_map);
// A fresh backend has no raw packets queued.
assert!(!vtp.pending_raw_pkts());
let mut pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN];
let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE);
// SAFETY: `hdr_raw` and `data_raw` are disjoint parts of `pkt_raw`, which lives until the
// end of the test. NOTE(review): assumed to be all that `VsockPacket::new` requires.
let mut packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() };
// Receiving a raw packet from an empty queue is an error.
assert_eq!(
vtp.recv_raw_pkt(&mut packet).unwrap_err().to_string(),
Error::EmptyRawPktsQueue.to_string()
);
// Build a data packet from this guest to the sibling.
packet.set_type(VSOCK_TYPE_STREAM);
packet.set_src_cid(CID);
packet.set_dst_cid(SIBLING_CID);
packet.set_dst_port(SIBLING_LISTENING_PORT);
packet.set_op(VSOCK_OP_RW);
packet.set_len(DATA_LEN as u32);
// Only the first four payload bytes are written; the rest stay zero.
packet
.data_slice()
.unwrap()
.copy_from(&[0xCAu8, 0xFEu8, 0xBAu8, 0xBEu8]);
assert!(vtp.send_pkt(&packet).is_ok());
// The packet must now be waiting in the sibling's raw packet queue.
assert!(sibling_backend.threads[0]
.lock()
.unwrap()
.thread_backend
.pending_raw_pkts());
let mut recvd_pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN];
let (recvd_hdr_raw, recvd_data_raw) = recvd_pkt_raw.split_at_mut(PKT_HEADER_SIZE);
// SAFETY: `recvd_hdr_raw` and `recvd_data_raw` are disjoint parts of `recvd_pkt_raw`,
// which lives until the end of the test. NOTE(review): assumed to be all that
// `VsockPacket::new` requires.
let mut recvd_packet =
unsafe { VsockPacket::new(recvd_hdr_raw, Some(recvd_data_raw)).unwrap() };
// Pull the packet out on the sibling's side.
assert!(sibling_backend.threads[0]
.lock()
.unwrap()
.thread_backend
.recv_raw_pkt(&mut recvd_packet)
.is_ok());
// Header fields and payload must match what was sent.
assert_eq!(recvd_packet.type_(), VSOCK_TYPE_STREAM);
assert_eq!(recvd_packet.src_cid(), CID);
assert_eq!(recvd_packet.dst_cid(), SIBLING_CID);
assert_eq!(recvd_packet.dst_port(), SIBLING_LISTENING_PORT);
assert_eq!(recvd_packet.op(), VSOCK_OP_RW);
assert_eq!(recvd_packet.len(), DATA_LEN as u32);
assert_eq!(recvd_data_raw[0], 0xCAu8);
assert_eq!(recvd_data_raw[1], 0xFEu8);
assert_eq!(recvd_data_raw[2], 0xBAu8);
assert_eq!(recvd_data_raw[3], 0xBEu8);
test_dir.close().unwrap();
}
}