use super::DEFAULT_RX_BUFFER_SIZE;
use super::error::SocketError;
use super::protocol::{
Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr,
};
use crate::Result;
use crate::config::read_config;
use crate::hal::Hal;
use crate::queue::{OwningQueue, VirtQueue};
use crate::transport::Transport;
use core::mem::size_of;
use log::debug;
use zerocopy::{FromBytes, IntoBytes};
/// Index of the receive virtqueue (device-to-driver packets).
pub(crate) const RX_QUEUE_IDX: u16 = 0;
/// Index of the transmit virtqueue (driver-to-device packets).
pub(crate) const TX_QUEUE_IDX: u16 = 1;
/// Index of the event virtqueue.
const EVENT_QUEUE_IDX: u16 = 2;
/// Size used for each of the three virtqueues (the `SIZE` const parameter of
/// `VirtQueue` / `OwningQueue`).
pub(crate) const QUEUE_SIZE: usize = 8;
/// Features the driver offers during negotiation (passed to `begin_init` in
/// `VirtIOSocket::new`).
const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX
    .union(Feature::RING_INDIRECT_DESC)
    .union(Feature::VERSION_1);
/// Per-connection state: the peer address, the local port, and the credit-based
/// flow-control counters that are exchanged in packet headers.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// Address of the peer; used as the destination of every packet sent on this connection.
    pub dst: VsockAddr,
    /// Local port packets on this connection are sent from.
    pub src_port: u32,
    /// Peer's most recently advertised `buf_alloc`, updated by `update_for_event`.
    peer_buf_alloc: u32,
    /// Peer's most recently advertised `fwd_cnt`, updated by `update_for_event`.
    peer_fwd_cnt: u32,
    /// Running count of payload bytes sent on this connection; advanced by
    /// `VirtIOSocket::send`.
    tx_cnt: u32,
    /// Value advertised to the peer as `buf_alloc` in every outgoing header.
    pub buf_alloc: u32,
    /// Running count of received bytes the local side has finished with (see
    /// `done_forwarding`); advertised to the peer as `fwd_cnt` in outgoing headers.
    fwd_cnt: u32,
    /// Set after a credit request is sent; cleared when a `CreditUpdate` event arrives.
    has_pending_credit_request: bool,
}
impl ConnectionInfo {
    /// Creates state for a connection to `destination` from local port `src_port`.
    ///
    /// All flow-control counters start at zero and no credit request is pending.
    pub fn new(destination: VsockAddr, src_port: u32) -> Self {
        Self {
            dst: destination,
            src_port,
            ..Default::default()
        }
    }

    /// Records the peer's advertised credit (`buf_alloc` / `fwd_cnt`) carried by `event`.
    ///
    /// A `CreditUpdate` event also answers any outstanding credit request, so the
    /// pending-request flag is cleared in that case.
    pub fn update_for_event(&mut self, event: &VsockEvent) {
        self.peer_buf_alloc = event.buffer_status.buffer_allocation;
        self.peer_fwd_cnt = event.buffer_status.forward_count;

        if matches!(event.event_type, VsockEventType::CreditUpdate) {
            self.has_pending_credit_request = false;
        }
    }

    /// Marks `length` received bytes as consumed by the local side.
    ///
    /// `fwd_cnt` is a free-running counter that wraps modulo 2^32 (virtio vsock flow
    /// control), so the addition wraps instead of overflowing, and truncating `length` to
    /// `u32` is consistent with that modular arithmetic.
    pub fn done_forwarding(&mut self, length: usize) {
        self.fwd_cnt = self.fwd_cnt.wrapping_add(length as u32);
    }

    /// Returns how many more payload bytes the peer can currently accept.
    ///
    /// The number of bytes still in flight is `tx_cnt - peer_fwd_cnt`. Both are wrapping
    /// counters, so their difference is computed with wrapping arithmetic. If the peer
    /// advertises fewer bytes than are in flight (e.g. it shrank its buffer, or sent a bogus
    /// `fwd_cnt`), the result saturates to 0 rather than underflowing into a huge credit.
    fn peer_free(&self) -> u32 {
        let in_flight = self.tx_cnt.wrapping_sub(self.peer_fwd_cnt);
        self.peer_buf_alloc.saturating_sub(in_flight)
    }

    /// Builds a packet header for this connection with addressing and credit fields filled
    /// in; the operation, length and flags are left at their defaults for the caller to set.
    fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
        VirtioVsockHdr {
            src_cid: src_cid.into(),
            dst_cid: self.dst.cid.into(),
            src_port: self.src_port.into(),
            dst_port: self.dst.port.into(),
            buf_alloc: self.buf_alloc.into(),
            fwd_cnt: self.fwd_cnt.into(),
            ..Default::default()
        }
    }
}
/// An event decoded from a packet received on the RX queue.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockEvent {
    /// Address the packet came from.
    pub source: VsockAddr,
    /// Address the packet was sent to.
    pub destination: VsockAddr,
    /// The sender's advertised credit (`buf_alloc` / `fwd_cnt`) from the packet header.
    pub buffer_status: VsockBufferStatus,
    /// What kind of event the packet represents.
    pub event_type: VsockEventType,
}
impl VsockEvent {
pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
self.source == connection_info.dst
&& self.destination.cid == guest_cid
&& self.destination.port == connection_info.src_port
}
fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
let op = header.op()?;
let buffer_status = VsockBufferStatus {
buffer_allocation: header.buf_alloc.into(),
forward_count: header.fwd_cnt.into(),
};
let source = header.source();
let destination = header.destination();
let event_type = match op {
VirtioVsockOp::Request => {
header.check_data_is_empty()?;
VsockEventType::ConnectionRequest
}
VirtioVsockOp::Response => {
header.check_data_is_empty()?;
VsockEventType::Connected
}
VirtioVsockOp::CreditUpdate => {
header.check_data_is_empty()?;
VsockEventType::CreditUpdate
}
VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
header.check_data_is_empty()?;
debug!("Disconnected from the peer");
let reason = if op == VirtioVsockOp::Rst {
DisconnectReason::Reset
} else {
DisconnectReason::Shutdown
};
VsockEventType::Disconnected { reason }
}
VirtioVsockOp::Rw => VsockEventType::Received {
length: header.len() as usize,
},
VirtioVsockOp::CreditRequest => {
header.check_data_is_empty()?;
VsockEventType::CreditRequest
}
VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
};
Ok(VsockEvent {
source,
destination,
buffer_status,
event_type,
})
}
}
/// Credit information advertised by the sender of a packet.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockBufferStatus {
    /// The sender's `buf_alloc` header field.
    pub buffer_allocation: u32,
    /// The sender's `fwd_cnt` header field.
    pub forward_count: u32,
}
/// Why a connection was disconnected by the peer.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DisconnectReason {
    /// The peer sent an `Rst` packet.
    Reset,
    /// The peer sent a `Shutdown` packet.
    Shutdown,
}
/// The kind of event a received packet represents, derived from its operation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VsockEventType {
    /// The peer requested a connection (`Request` packet).
    ConnectionRequest,
    /// The peer accepted our connection request (`Response` packet).
    Connected,
    /// The peer closed the connection (`Rst` or `Shutdown` packet).
    Disconnected {
        /// Which of the two packet types caused the disconnect.
        reason: DisconnectReason,
    },
    /// The peer sent data (`Rw` packet).
    Received {
        /// Payload length taken from the header's `len` field.
        length: usize,
    },
    /// The peer asked for our current credit (`CreditRequest` packet).
    CreditRequest,
    /// The peer advertised its current credit (`CreditUpdate` packet).
    CreditUpdate,
}
/// Low-level driver for a VirtIO socket (vsock) device.
///
/// `RX_BUFFER_SIZE` is the size of each receive buffer owned by the RX queue; `new` asserts
/// it is larger than a packet header.
pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE>
{
    /// Transport used to talk to the device.
    transport: T,
    /// Receive queue, wrapped in an `OwningQueue` that holds the RX buffers.
    rx: OwningQueue<H, QUEUE_SIZE, RX_BUFFER_SIZE>,
    /// Transmit queue.
    tx: VirtQueue<H, { QUEUE_SIZE }>,
    /// Event queue.
    // NOTE(review): this queue is created in `new` and unset on drop, but nothing in this
    // file reads from it.
    event: VirtQueue<H, { QUEUE_SIZE }>,
    /// CID read from the device config space in `new`.
    guest_cid: u64,
}
impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
    for VirtIOSocket<H, T, RX_BUFFER_SIZE>
{
    /// Unsets the RX, TX and event queues on the transport, in that order.
    fn drop(&mut self) {
        for queue_index in [RX_QUEUE_IDX, TX_QUEUE_IDX, EVENT_QUEUE_IDX] {
            self.transport.queue_unset(queue_index);
        }
    }
}
impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BUFFER_SIZE> {
pub fn new(mut transport: T) -> Result<Self> {
assert!(RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>());
let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
let guest_cid = transport.read_consistent(|| {
Ok(
(read_config!(transport, VirtioVsockConfig, guest_cid_low)? as u64)
| ((read_config!(transport, VirtioVsockConfig, guest_cid_high)? as u64) << 32),
)
})?;
debug!("guest cid: {guest_cid:?}");
let rx = VirtQueue::new(
&mut transport,
RX_QUEUE_IDX,
negotiated_features.contains(Feature::RING_INDIRECT_DESC),
negotiated_features.contains(Feature::RING_EVENT_IDX),
)?;
let tx = VirtQueue::new(
&mut transport,
TX_QUEUE_IDX,
negotiated_features.contains(Feature::RING_INDIRECT_DESC),
negotiated_features.contains(Feature::RING_EVENT_IDX),
)?;
let event = VirtQueue::new(
&mut transport,
EVENT_QUEUE_IDX,
negotiated_features.contains(Feature::RING_INDIRECT_DESC),
negotiated_features.contains(Feature::RING_EVENT_IDX),
)?;
let rx = OwningQueue::new(rx)?;
transport.finish_init();
if rx.should_notify() {
transport.notify(RX_QUEUE_IDX);
}
Ok(Self {
transport,
rx,
tx,
event,
guest_cid,
})
}
pub fn guest_cid(&self) -> u64 {
self.guest_cid
}
pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Request.into(),
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Response.into(),
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::CreditRequest.into(),
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
let len = buffer.len() as u32;
let header = VirtioVsockHdr {
op: VirtioVsockOp::Rw.into(),
len: len.into(),
..connection_info.new_header(self.guest_cid)
};
connection_info.tx_cnt += len;
self.send_packet_to_tx_queue(&header, buffer)
}
fn check_peer_buffer_is_sufficient(
&mut self,
connection_info: &mut ConnectionInfo,
buffer_len: usize,
) -> Result {
if connection_info.peer_free() as usize >= buffer_len {
Ok(())
} else {
if !connection_info.has_pending_credit_request {
self.request_credit(connection_info)?;
connection_info.has_pending_credit_request = true;
}
Err(SocketError::InsufficientBufferSpaceInPeer.into())
}
}
pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::CreditUpdate.into(),
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
pub fn poll(
&mut self,
handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
) -> Result<Option<VsockEvent>> {
self.rx.poll(&mut self.transport, |buffer| {
let (header, body) = read_header_and_body(buffer)?;
VsockEvent::from_header(&header).and_then(|event| handler(event, body))
})
}
pub fn shutdown_with_hints(
&mut self,
connection_info: &ConnectionInfo,
hints: StreamShutdown,
) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Shutdown.into(),
flags: hints.into(),
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
self.shutdown_with_hints(
connection_info,
StreamShutdown::SEND | StreamShutdown::RECEIVE,
)
}
pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Rst.into(),
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])?;
Ok(())
}
fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
let _len = if buffer.is_empty() {
self.tx
.add_notify_wait_pop(&[header.as_bytes()], &mut [], &mut self.transport)?
} else {
self.tx.add_notify_wait_pop(
&[header.as_bytes(), buffer],
&mut [],
&mut self.transport,
)?
};
Ok(())
}
}
/// Splits a received buffer into its packet header and the body that follows it.
///
/// The body is exactly `header.len()` bytes; any trailing bytes in `buffer` are ignored.
///
/// # Errors
///
/// - `SocketError::BufferTooShort` if `buffer` cannot hold the header or the declared body.
/// - `SocketError::InvalidNumber` if the header length plus body length overflows `usize`.
fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
    const HEADER_LEN: usize = size_of::<VirtioVsockHdr>();

    let (header, _rest) =
        VirtioVsockHdr::read_from_prefix(buffer).map_err(|_| SocketError::BufferTooShort)?;

    let body_len = header.len() as usize;
    let body_end = match HEADER_LEN.checked_add(body_len) {
        Some(end) => end,
        None => return Err(SocketError::InvalidNumber.into()),
    };

    let body = buffer
        .get(HEADER_LEN..body_end)
        .ok_or(SocketError::BufferTooShort)?;
    Ok((header, body))
}
// Unit tests for the vsock driver, run against the fake HAL and fake transport.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::ReadOnly,
        hal::fake::FakeHal,
        transport::{
            DeviceType,
            fake::{FakeTransport, QueueStatus, State},
        },
    };
    use alloc::{sync::Arc, vec};
    use std::sync::Mutex;

    // Checks that `new` assembles the guest CID from the low and high 32-bit config fields.
    #[test]
    fn config() {
        // CID 66 (0x42) in the low half, zero in the high half.
        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        // One fake queue status per virtqueue: RX, TX and event.
        let state = Arc::new(Mutex::new(State::new(
            vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            config_space,
        )));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: state.clone(),
        };
        let socket =
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
        assert_eq!(socket.guest_cid(), 0x00_0000_0042);
    }
}