use std::{
io::{ErrorKind, Read, Write},
num::Wrapping,
os::unix::prelude::{AsRawFd, RawFd},
};
use log::info;
use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
use vm_memory::{bitmap::BitmapSlice, Bytes, VolatileSlice};
use crate::{
rxops::*,
rxqueue::*,
txbuf::*,
vhu_vsock::{
Error, Result, VSOCK_FLAGS_SHUTDOWN_RCV, VSOCK_FLAGS_SHUTDOWN_SEND,
VSOCK_OP_CREDIT_REQUEST, VSOCK_OP_CREDIT_UPDATE, VSOCK_OP_REQUEST, VSOCK_OP_RESPONSE,
VSOCK_OP_RST, VSOCK_OP_RW, VSOCK_OP_SHUTDOWN, VSOCK_TYPE_STREAM,
},
vhu_vsock_thread::VhostUserVsockThread,
};
/// State of one vsock stream connection between a host-side stream `S` and a
/// port in the guest, including the credit counters used for flow control in
/// both directions.
#[derive(Debug)]
pub(crate) struct VsockConnection<S> {
/// Host-side end of the connection (read from for guest-bound data, written
/// to with data coming from the guest).
pub stream: S,
/// Set once the handshake completed (a RESPONSE was sent or received); while
/// false, pending data reads are answered with RST.
pub connect: bool,
/// Port on the guest side; used as the destination port of outgoing packets.
pub peer_port: u32,
/// Operations waiting to be turned into packets for the guest.
pub rx_queue: RxQueue,
/// CID used as the source CID of outgoing packets.
local_cid: u64,
/// Port used as the source port of outgoing packets.
pub local_port: u32,
/// CID used as the destination CID of outgoing packets.
pub guest_cid: u64,
/// Total bytes of guest data written to `stream`; advertised to the guest.
pub fwd_cnt: Wrapping<u32>,
/// Value of `fwd_cnt` the last time it was sent to the guest in a packet
/// that updates its view of our credit.
last_fwd_cnt: Wrapping<u32>,
/// Receive buffer size most recently advertised by the guest.
peer_buf_alloc: u32,
/// Forwarded-byte count most recently advertised by the guest.
peer_fwd_cnt: Wrapping<u32>,
/// Total bytes of data sent to the guest.
rx_cnt: Wrapping<u32>,
/// Epoll instance the stream's file descriptor is (re)registered with.
pub epoll_fd: RawFd,
/// Guest data that could not yet be written to `stream`.
pub tx_buf: LocalTxBuf,
/// Capacity of `tx_buf`, advertised to the guest as our `buf_alloc`.
tx_buffer_size: u32,
}
impl<S: AsRawFd + Read + Write> VsockConnection<S> {
pub fn new_local_init(
stream: S,
local_cid: u64,
local_port: u32,
guest_cid: u64,
guest_port: u32,
epoll_fd: RawFd,
tx_buffer_size: u32,
) -> Self {
Self {
stream,
connect: false,
peer_port: guest_port,
rx_queue: RxQueue::new(),
local_cid,
local_port,
guest_cid,
fwd_cnt: Wrapping(0),
last_fwd_cnt: Wrapping(0),
peer_buf_alloc: 0,
peer_fwd_cnt: Wrapping(0),
rx_cnt: Wrapping(0),
epoll_fd,
tx_buf: LocalTxBuf::new(tx_buffer_size),
tx_buffer_size,
}
}
#[allow(clippy::too_many_arguments)]
pub fn new_peer_init(
stream: S,
local_cid: u64,
local_port: u32,
guest_cid: u64,
guest_port: u32,
epoll_fd: RawFd,
peer_buf_alloc: u32,
tx_buffer_size: u32,
) -> Self {
let mut rx_queue = RxQueue::new();
rx_queue.enqueue(RxOps::Response);
Self {
stream,
connect: false,
peer_port: guest_port,
rx_queue,
local_cid,
local_port,
guest_cid,
fwd_cnt: Wrapping(0),
last_fwd_cnt: Wrapping(0),
peer_buf_alloc,
peer_fwd_cnt: Wrapping(0),
rx_cnt: Wrapping(0),
epoll_fd,
tx_buf: LocalTxBuf::new(tx_buffer_size),
tx_buffer_size,
}
}
pub fn set_peer_port(&mut self, peer_port: u32) {
self.peer_port = peer_port;
}
pub fn recv_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()> {
self.init_pkt(pkt);
match self.rx_queue.dequeue() {
Some(RxOps::Request) => {
pkt.set_op(VSOCK_OP_REQUEST);
Ok(())
}
Some(RxOps::Rw) => {
if !self.connect {
pkt.set_op(VSOCK_OP_RST);
return Ok(());
}
if self.need_credit_update_from_peer() {
self.last_fwd_cnt = self.fwd_cnt;
pkt.set_op(VSOCK_OP_CREDIT_REQUEST);
return Ok(());
}
let buf = pkt.data_slice().ok_or(Error::PktBufMissing)?;
let max_read_len = std::cmp::min(buf.len(), self.peer_avail_credit());
if let Ok(read_cnt) = buf.read_from(0, &mut self.stream, max_read_len) {
if read_cnt == 0 {
pkt.set_op(VSOCK_OP_SHUTDOWN)
.set_flag(VSOCK_FLAGS_SHUTDOWN_RCV)
.set_flag(VSOCK_FLAGS_SHUTDOWN_SEND);
} else {
pkt.set_op(VSOCK_OP_RW).set_len(read_cnt as u32);
VhostUserVsockThread::epoll_register(
self.epoll_fd,
self.stream.as_raw_fd(),
epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
)?;
}
self.rx_cnt += Wrapping(pkt.len());
self.last_fwd_cnt = self.fwd_cnt;
}
Ok(())
}
Some(RxOps::Response) => {
self.connect = true;
pkt.set_op(VSOCK_OP_RESPONSE);
Ok(())
}
Some(RxOps::CreditUpdate) => {
if !self.rx_queue.pending_rx() {
pkt.set_op(VSOCK_OP_CREDIT_UPDATE);
self.last_fwd_cnt = self.fwd_cnt;
}
Ok(())
}
_ => Err(Error::NoRequestRx),
}
}
pub fn send_pkt<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) -> Result<()> {
self.peer_buf_alloc = pkt.buf_alloc();
self.peer_fwd_cnt = Wrapping(pkt.fwd_cnt());
match pkt.op() {
VSOCK_OP_RESPONSE => {
let response = format!("OK {}\n", self.peer_port);
self.stream.write_all(response.as_bytes()).unwrap();
self.connect = true;
}
VSOCK_OP_RW => {
match pkt.data_slice() {
None => {
info!(
"Dropping empty packet from guest (lp={}, pp={})",
self.local_port, self.peer_port
);
return Ok(());
}
Some(buf) => {
if let Err(err) = self.send_bytes(buf) {
dbg!("err:{:?}", err);
return Ok(());
}
}
}
}
VSOCK_OP_CREDIT_UPDATE => {
if VhostUserVsockThread::epoll_modify(
self.epoll_fd,
self.stream.as_raw_fd(),
epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
)
.is_err()
{
VhostUserVsockThread::epoll_register(
self.epoll_fd,
self.stream.as_raw_fd(),
epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
)
.unwrap();
};
}
VSOCK_OP_CREDIT_REQUEST => {
self.rx_queue.enqueue(RxOps::CreditUpdate);
}
VSOCK_OP_SHUTDOWN => {
let recv_off = pkt.flags() & VSOCK_FLAGS_SHUTDOWN_RCV != 0;
let send_off = pkt.flags() & VSOCK_FLAGS_SHUTDOWN_SEND != 0;
if recv_off && send_off && self.tx_buf.is_empty() {
self.rx_queue.enqueue(RxOps::Reset);
}
}
_ => {}
}
Ok(())
}
fn send_bytes<B: BitmapSlice>(&mut self, buf: &VolatileSlice<B>) -> Result<()> {
if !self.tx_buf.is_empty() {
return self.tx_buf.push(buf);
}
let written_count = match buf.write_to(0, &mut self.stream, buf.len()) {
Ok(cnt) => cnt,
Err(vm_memory::VolatileMemoryError::IOError(e)) => {
if e.kind() == ErrorKind::WouldBlock {
0
} else {
dbg!("send_bytes error: {:?}", e);
return Err(Error::UnixWrite);
}
}
Err(e) => {
dbg!("send_bytes error: {:?}", e);
return Err(Error::UnixWrite);
}
};
if written_count > 0 {
self.fwd_cnt += Wrapping(written_count as u32);
let free_space = self
.tx_buffer_size
.wrapping_sub((self.fwd_cnt - self.last_fwd_cnt).0);
if free_space < self.tx_buffer_size / 4 {
self.rx_queue.enqueue(RxOps::CreditUpdate);
}
}
if written_count != buf.len() {
return self.tx_buf.push(&buf.offset(written_count).unwrap());
}
Ok(())
}
fn init_pkt<'a, 'b, B: BitmapSlice>(
&self,
pkt: &'a mut VsockPacket<'b, B>,
) -> &'a mut VsockPacket<'b, B> {
pkt.set_header_from_raw(&[0u8; PKT_HEADER_SIZE]).unwrap();
pkt.set_src_cid(self.local_cid)
.set_dst_cid(self.guest_cid)
.set_src_port(self.local_port)
.set_dst_port(self.peer_port)
.set_type(VSOCK_TYPE_STREAM)
.set_buf_alloc(self.tx_buffer_size)
.set_fwd_cnt(self.fwd_cnt.0)
}
fn peer_avail_credit(&self) -> usize {
(Wrapping(self.peer_buf_alloc) - (self.rx_cnt - self.peer_fwd_cnt)).0 as usize
}
fn need_credit_update_from_peer(&self) -> bool {
self.peer_avail_credit() == 0
}
}
#[cfg(test)]
mod tests {
    use byteorder::{ByteOrder, LittleEndian};

    use super::*;
    use crate::vhu_vsock::{VSOCK_HOST_CID, VSOCK_OP_RW, VSOCK_TYPE_STREAM};
    use std::io::Result as IoResult;
    use std::ops::Deref;
    use virtio_bindings::bindings::virtio_ring::{VRING_DESC_F_NEXT, VRING_DESC_F_WRITE};
    use virtio_queue::{mock::MockSplitQueue, Descriptor, DescriptorChain, Queue, QueueOwnedT};
    use vm_memory::{
        Address, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryLoadGuard,
        GuestMemoryMmap,
    };

    const CONN_TX_BUF_SIZE: u32 = 64 * 1024;

    /// Parameters for the header descriptor of a test descriptor chain.
    struct HeadParams {
        head_len: usize,
        data_len: u32,
    }

    impl HeadParams {
        fn new(head_len: usize, data_len: u32) -> Self {
            Self { head_len, data_len }
        }

        /// Build the raw header bytes. Only a full-size header gets its `len`
        /// field populated; other sizes produce an all-zero buffer.
        fn construct_head(&self) -> Vec<u8> {
            let mut header = vec![0_u8; self.head_len];
            if self.head_len == PKT_HEADER_SIZE {
                // Byte offset of the `len` field in the vsock packet header.
                const HDROFF_LEN: usize = 24;
                LittleEndian::write_u32(&mut header[HDROFF_LEN..], self.data_len);
            }
            header
        }
    }

    /// Build a mock split queue holding one descriptor chain: a header
    /// descriptor followed by `data_chain_len` data descriptors of
    /// `head_data_len` bytes each. Returns the guest memory and the chain.
    fn prepare_desc_chain_vsock(
        write_only: bool,
        head_params: &HeadParams,
        data_chain_len: u16,
        head_data_len: u32,
    ) -> (
        GuestMemoryAtomic<GuestMemoryMmap>,
        DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
    ) {
        let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x1000)]).unwrap();
        let virt_queue = MockSplitQueue::new(&mem, 16);
        let mut next_addr = virt_queue.desc_table().total_size() + 0x100;
        let mut flags = 0;

        if write_only {
            flags |= VRING_DESC_F_WRITE;
        }

        let mut head_flags = if data_chain_len > 0 {
            flags | VRING_DESC_F_NEXT
        } else {
            flags
        };

        // Header descriptor.
        let header = head_params.construct_head();
        let head_desc =
            Descriptor::new(next_addr, head_params.head_len as u32, head_flags as u16, 1);
        mem.write(&header, head_desc.addr()).unwrap();
        assert!(virt_queue.desc_table().store(0, head_desc).is_ok());
        next_addr += head_params.head_len as u64;

        // Put descriptor 0 in the avail ring and advance avail idx to 1.
        mem.write_obj(0u16, virt_queue.avail_addr().unchecked_add(4))
            .unwrap();
        mem.write_obj(1u16, virt_queue.avail_addr().unchecked_add(2))
            .unwrap();

        // Data descriptors; the last one terminates the chain.
        for i in 0..(data_chain_len) {
            if i == data_chain_len - 1 {
                head_flags &= !VRING_DESC_F_NEXT;
            }
            let data = vec![0_u8; head_data_len as usize];
            let data_desc = Descriptor::new(next_addr, data.len() as u32, head_flags as u16, i + 2);
            mem.write(&data, data_desc.addr()).unwrap();
            assert!(virt_queue.desc_table().store(i + 1, data_desc).is_ok());
            next_addr += head_data_len as u64;
        }

        (
            GuestMemoryAtomic::new(mem.clone()),
            virt_queue
                .create_queue::<Queue>()
                .unwrap()
                .iter(GuestMemoryAtomic::new(mem.clone()).memory())
                .unwrap()
                .next()
                .unwrap(),
        )
    }

    /// In-memory stand-in for the host-side stream.
    ///
    /// `write` replaces the stored bytes with the latest buffer; `read` copies
    /// them out without consuming them.
    struct VsockDummySocket {
        data: Vec<u8>,
    }

    impl VsockDummySocket {
        fn new() -> Self {
            Self { data: Vec::new() }
        }
    }

    impl Write for VsockDummySocket {
        fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
            self.data.clear();
            self.data.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> IoResult<()> {
            Ok(())
        }
    }

    impl Read for VsockDummySocket {
        fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
            // Copy at most buf.len() bytes so a short destination cannot panic.
            let n = self.data.len().min(buf.len());
            buf[..n].copy_from_slice(&self.data[..n]);
            Ok(n)
        }
    }

    impl AsRawFd for VsockDummySocket {
        // No real descriptor: epoll operations on it are expected to fail.
        fn as_raw_fd(&self) -> RawFd {
            -1
        }
    }

    #[test]
    fn test_vsock_conn_init() {
        // Host-initiated connection.
        let dummy_file = VsockDummySocket::new();
        let mut conn_local = VsockConnection::new_local_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            CONN_TX_BUF_SIZE,
        );

        assert!(!conn_local.connect);
        assert_eq!(conn_local.peer_port, 5001);
        assert_eq!(conn_local.rx_queue, RxQueue::new());
        assert_eq!(conn_local.local_cid, VSOCK_HOST_CID);
        assert_eq!(conn_local.local_port, 5000);
        assert_eq!(conn_local.guest_cid, 3);

        conn_local.set_peer_port(5002);
        assert_eq!(conn_local.peer_port, 5002);

        // Guest-initiated connection: a Response must be queued.
        let dummy_file = VsockDummySocket::new();
        let mut conn_peer = VsockConnection::new_peer_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            65536,
            CONN_TX_BUF_SIZE,
        );

        assert!(!conn_peer.connect);
        assert_eq!(conn_peer.peer_port, 5001);
        assert_eq!(conn_peer.rx_queue.dequeue().unwrap(), RxOps::Response);
        assert!(!conn_peer.rx_queue.pending_rx());
        assert_eq!(conn_peer.local_cid, VSOCK_HOST_CID);
        assert_eq!(conn_peer.local_port, 5000);
        assert_eq!(conn_peer.guest_cid, 3);
        assert_eq!(conn_peer.peer_buf_alloc, 65536);
    }

    #[test]
    fn test_vsock_conn_credit() {
        let dummy_file = VsockDummySocket::new();
        let mut conn_local = VsockConnection::new_local_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            CONN_TX_BUF_SIZE,
        );

        // No buffer advertised yet: no credit.
        assert_eq!(conn_local.peer_avail_credit(), 0);
        assert!(conn_local.need_credit_update_from_peer());

        conn_local.peer_buf_alloc = 65536;
        assert_eq!(conn_local.peer_avail_credit(), 65536);
        assert!(!conn_local.need_credit_update_from_peer());

        // Half of the peer buffer in flight.
        conn_local.rx_cnt = Wrapping(32768);
        assert_eq!(conn_local.peer_avail_credit(), 32768);
        assert!(!conn_local.need_credit_update_from_peer());

        // Peer buffer fully used.
        conn_local.rx_cnt = Wrapping(65536);
        assert_eq!(conn_local.peer_avail_credit(), 0);
        assert!(conn_local.need_credit_update_from_peer());
    }

    #[test]
    fn test_vsock_conn_init_pkt() {
        let head_params = HeadParams::new(PKT_HEADER_SIZE, 10);
        let dummy_file = VsockDummySocket::new();
        let conn_local = VsockConnection::new_local_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            CONN_TX_BUF_SIZE,
        );

        let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 2, 10);
        let mem = mem.memory();
        let mut pkt =
            VsockPacket::from_rx_virtq_chain(mem.deref(), &mut descr_chain, CONN_TX_BUF_SIZE)
                .unwrap();

        conn_local.init_pkt(&mut pkt);

        assert_eq!(pkt.src_cid(), VSOCK_HOST_CID);
        assert_eq!(pkt.dst_cid(), 3);
        assert_eq!(pkt.src_port(), 5000);
        assert_eq!(pkt.dst_port(), 5001);
        assert_eq!(pkt.type_(), VSOCK_TYPE_STREAM);
        assert_eq!(pkt.buf_alloc(), CONN_TX_BUF_SIZE);
        assert_eq!(pkt.fwd_cnt(), 0);
    }

    #[test]
    fn test_vsock_conn_recv_pkt() {
        let head_params = HeadParams::new(PKT_HEADER_SIZE, 5);
        let dummy_file = VsockDummySocket::new();
        let mut conn_local = VsockConnection::new_local_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            CONN_TX_BUF_SIZE,
        );

        let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 1, 5);
        let mem = mem.memory();
        let mut pkt =
            VsockPacket::from_rx_virtq_chain(mem.deref(), &mut descr_chain, CONN_TX_BUF_SIZE)
                .unwrap();

        // 1. Request -> REQUEST packet.
        conn_local.rx_queue.enqueue(RxOps::Request);
        let op_req = conn_local.recv_pkt(&mut pkt);
        assert!(op_req.is_ok());
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(pkt.op(), VSOCK_OP_REQUEST);

        // 2. Rw before the connection is established -> RST.
        conn_local.rx_queue.enqueue(RxOps::Rw);
        let op_rst = conn_local.recv_pkt(&mut pkt);
        assert!(op_rst.is_ok());
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(pkt.op(), VSOCK_OP_RST);

        // 3. Rw while connected but without peer credit -> CREDIT_REQUEST.
        conn_local.connect = true;
        conn_local.rx_queue.enqueue(RxOps::Rw);
        conn_local.fwd_cnt = Wrapping(1024);
        let op_credit_update = conn_local.recv_pkt(&mut pkt);
        assert!(op_credit_update.is_ok());
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(pkt.op(), VSOCK_OP_CREDIT_REQUEST);
        assert_eq!(conn_local.last_fwd_cnt, Wrapping(1024));

        // 4. Rw with credit but an empty stream (read returns 0) -> SHUTDOWN.
        conn_local.peer_buf_alloc = 65536;
        conn_local.rx_queue.enqueue(RxOps::Rw);
        let op_zero_read_shutdown = conn_local.recv_pkt(&mut pkt);
        assert!(op_zero_read_shutdown.is_ok());
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(conn_local.rx_cnt, Wrapping(0));
        assert_eq!(conn_local.last_fwd_cnt, Wrapping(1024));
        assert_eq!(pkt.op(), VSOCK_OP_SHUTDOWN);
        assert_eq!(
            pkt.flags(),
            VSOCK_FLAGS_SHUTDOWN_RCV | VSOCK_FLAGS_SHUTDOWN_SEND
        );

        // 5. Rw with data waiting in the stream: the data lands in the packet.
        //    The call still returns Err because (re)registering the stream
        //    with epoll fails for the dummy epoll fd (-1).
        conn_local.stream.write_all(b"hello").unwrap();
        conn_local.rx_queue.enqueue(RxOps::Rw);
        let op_data_read = conn_local.recv_pkt(&mut pkt);
        assert!(op_data_read.is_err());
        assert_eq!(pkt.op(), VSOCK_OP_RW);
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(pkt.len(), 5);
        let buf = &mut [0u8; 5];
        assert!(pkt.data_slice().unwrap().read_slice(buf, 0).is_ok());
        assert_eq!(buf, b"hello");

        // 6. Response -> RESPONSE packet and the connection is marked connected.
        conn_local.rx_queue.enqueue(RxOps::Response);
        let op_response = conn_local.recv_pkt(&mut pkt);
        assert!(op_response.is_ok());
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(pkt.op(), VSOCK_OP_RESPONSE);
        assert!(conn_local.connect);

        // 7. CreditUpdate with nothing else pending -> CREDIT_UPDATE.
        conn_local.rx_queue.enqueue(RxOps::CreditUpdate);
        let op_credit_update = conn_local.recv_pkt(&mut pkt);
        assert!(!conn_local.rx_queue.pending_rx());
        assert!(op_credit_update.is_ok());
        assert_eq!(pkt.op(), VSOCK_OP_CREDIT_UPDATE);
        assert_eq!(conn_local.last_fwd_cnt, Wrapping(1024));

        // 8. Nothing pending -> error.
        let op_error = conn_local.recv_pkt(&mut pkt);
        assert!(op_error.is_err());
    }

    #[test]
    fn test_vsock_conn_send_pkt() {
        let head_params = HeadParams::new(PKT_HEADER_SIZE, 5);
        let dummy_file = VsockDummySocket::new();
        let mut conn_local = VsockConnection::new_local_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            CONN_TX_BUF_SIZE,
        );

        let (mem, mut descr_chain) = prepare_desc_chain_vsock(false, &head_params, 1, 5);
        let mem = mem.memory();
        let mut pkt =
            VsockPacket::from_tx_virtq_chain(mem.deref(), &mut descr_chain, CONN_TX_BUF_SIZE)
                .unwrap();

        // Any packet (here op 0 from the zeroed header) refreshes peer credit.
        pkt.set_buf_alloc(65536).set_fwd_cnt(1024);
        let credit_check = conn_local.send_pkt(&pkt);
        assert!(credit_check.is_ok());
        assert_eq!(conn_local.peer_buf_alloc, 65536);
        assert_eq!(conn_local.peer_fwd_cnt, Wrapping(1024));

        // RESPONSE -> "OK <port>\n" written to the stream, connection established.
        pkt.set_op(VSOCK_OP_RESPONSE);
        let peer_response = conn_local.send_pkt(&pkt);
        assert!(peer_response.is_ok());
        assert!(conn_local.connect);
        let mut resp_buf = vec![0; 8];
        conn_local.stream.read_exact(&mut resp_buf).unwrap();
        assert_eq!(resp_buf, b"OK 5001\n");

        // RW -> payload forwarded to the stream.
        pkt.set_op(VSOCK_OP_RW);
        let buf = b"hello";
        assert!(pkt.data_slice().unwrap().write_slice(buf, 0).is_ok());
        let rw_response = conn_local.send_pkt(&pkt);
        assert!(rw_response.is_ok());
        let mut resp_buf = vec![0; 5];
        conn_local.stream.read_exact(&mut resp_buf).unwrap();
        assert_eq!(resp_buf, b"hello");

        // CREDIT_REQUEST -> a CreditUpdate is queued.
        pkt.set_op(VSOCK_OP_CREDIT_REQUEST);
        let credit_response = conn_local.send_pkt(&pkt);
        assert!(credit_response.is_ok());
        assert_eq!(conn_local.rx_queue.peek().unwrap(), RxOps::CreditUpdate);

        // Full SHUTDOWN with an empty TX buffer -> a Reset is queued.
        pkt.set_op(VSOCK_OP_SHUTDOWN);
        pkt.set_flags(VSOCK_FLAGS_SHUTDOWN_RCV | VSOCK_FLAGS_SHUTDOWN_SEND);
        let shutdown_response = conn_local.send_pkt(&pkt);
        assert!(shutdown_response.is_ok());
        assert!(conn_local.rx_queue.contains(RxOps::Reset.bitmask()));
    }
}