use alloc::sync::Arc;
use core::task::Context;
use ax_errno::{AxError, AxResult, ax_bail, ax_err_type};
use ax_io::prelude::*;
use ax_sync::Mutex;
use axpoll::{IoEvents, Pollable};
use super::connection_manager::*;
use crate::{
RecvFlags, RecvOptions, SendOptions, Shutdown,
device::*,
general::GeneralOptions,
options::{Configurable, GetSocketOption, SetSocketOption},
state::*,
vsock::{VsockAddr, VsockConnId, VsockTransport, VsockTransportOps},
};
pub struct VsockStreamTransport {
conn_id: Mutex<Option<VsockConnId>>,
connection: Mutex<Option<Arc<Mutex<Connection>>>>,
state: StateLock,
general: GeneralOptions,
}
impl VsockStreamTransport {
pub fn new() -> Self {
Self {
conn_id: Mutex::new(None),
connection: Mutex::new(None),
state: StateLock::new(State::Idle),
general: GeneralOptions::new(),
}
}
fn get_connection(&self) -> AxResult<Arc<Mutex<Connection>>> {
self.connection.lock().clone().ok_or(AxError::NotConnected)
}
}
impl Default for VsockStreamTransport {
fn default() -> Self {
Self::new()
}
}
impl Configurable for VsockStreamTransport {
fn get_option_inner(&self, opt: &mut GetSocketOption) -> AxResult<bool> {
self.general.get_option_inner(opt)
}
fn set_option_inner(&self, opt: SetSocketOption) -> AxResult<bool> {
self.general.set_option_inner(opt)
}
}
impl VsockTransportOps for VsockStreamTransport {
fn bind(&self, mut local_addr: VsockAddr) -> AxResult<()> {
self.state
.lock(State::Idle)
.map_err(|_| ax_err_type!(InvalidInput, "already bound"))?
.transit(State::Idle, || {
let mut manager = VSOCK_CONN_MANAGER.lock();
if local_addr.port == 0 {
local_addr.port = manager.allocate_port()?;
}
let conn_id = VsockConnId::listening(local_addr.port);
let conn =
manager.create_connection(conn_id, local_addr, None, ConnectionState::Idle);
*self.conn_id.lock() = Some(conn_id);
*self.connection.lock() = Some(conn);
trace!("Vsock binding to {:?}", local_addr);
Ok(())
})?;
Ok(())
}
fn listen(&self) -> AxResult<()> {
let guard = self
.state
.lock(State::Idle)
.map_err(|_| ax_err_type!(InvalidInput, "invalid state for listen"))?;
guard.transit(State::Listening, || {
let conn = self.get_connection()?;
let local_addr = conn.lock().local_addr();
VSOCK_CONN_MANAGER.lock().listen(local_addr)?;
vsock_listen(local_addr)?;
conn.lock().set_state(ConnectionState::Listening);
trace!("Vsock listening on {:?}", local_addr);
Ok(())
})
}
fn accept(&self) -> AxResult<(VsockTransport, VsockAddr)> {
if self.state.get() != State::Listening {
ax_bail!(InvalidInput, "not listening");
}
let conn = self.get_connection()?;
let local_port = conn.lock().local_addr().port;
self.general.recv_poller(self, || {
let mut manager = VSOCK_CONN_MANAGER.lock();
if !manager.can_accept(local_port) {
return Err(AxError::WouldBlock);
}
let (conn_id, peer_addr) = manager.accept(local_port)?;
let conn = manager.get_connection(conn_id).ok_or(AxError::NotFound)?;
let new_transport = VsockStreamTransport {
conn_id: Mutex::new(Some(conn_id)),
connection: Mutex::new(Some(conn)),
state: StateLock::new(State::Connected),
general: GeneralOptions::default(),
};
Ok((VsockTransport::Stream(new_transport), peer_addr))
})
}
fn connect(&self, peer_addr: VsockAddr) -> AxResult<()> {
let guard = self.state.lock(State::Idle).map_err(|state| match state {
State::Idle => unreachable!(),
State::Listening => ax_err_type!(InvalidInput, "already listening"),
State::Connecting => ax_err_type!(InProgress),
State::Connected => ax_err_type!(AlreadyConnected),
_ => ax_err_type!(AlreadyConnected),
})?;
guard.transit(State::Connecting, || {
let mut manager = VSOCK_CONN_MANAGER.lock();
let existing_conn = self.connection.lock();
let local_port = if let Some(conn) = existing_conn.as_ref() {
let conn_guard = conn.lock();
match conn_guard.state() {
ConnectionState::Idle => {
conn_guard.local_addr().port
}
_ => {
ax_bail!(InvalidInput, "already connected or listening");
}
}
} else {
manager.allocate_port()?
};
drop(existing_conn);
let local_addr = VsockAddr {
cid: vsock_guest_cid()?,
port: local_port,
};
let conn_id = VsockConnId {
peer_addr,
local_port,
};
let conn = manager.create_connection(
conn_id,
local_addr,
Some(peer_addr),
ConnectionState::Connecting,
);
*self.conn_id.lock() = Some(conn_id);
*self.connection.lock() = Some(conn.clone());
drop(manager);
vsock_connect(conn_id)?;
debug!("Vsock connecting from {} to {:?}", local_port, peer_addr);
Ok(())
})?;
self.general.send_poller(self, || {
let conn = self.get_connection()?;
let state = conn.lock().state();
match state {
ConnectionState::Connected => Ok(()),
ConnectionState::Connecting => Err(AxError::WouldBlock),
_ => Err(ax_err_type!(ConnectionRefused)),
}
})
}
fn send(&self, mut src: impl Read + IoBuf, _options: SendOptions) -> AxResult<usize> {
let conn = self.get_connection()?;
let conn_guard = conn.lock();
if conn_guard.state() != ConnectionState::Connected {
return Err(AxError::NotConnected);
}
if conn_guard.tx_closed() {
return Err(AxError::NotConnected);
}
let conn_id = self.conn_id.lock().ok_or(AxError::NotConnected)?;
drop(conn_guard);
let result = src.write_to(&mut ax_io::write_fn(|buf| vsock_send(conn_id, buf)));
conn.lock().add_tx_bytes(result.unwrap_or(0));
result
}
fn recv(&self, mut dst: impl Write, options: RecvOptions) -> AxResult<usize> {
let conn = self.get_connection()?;
let extra_nb = options.flags.contains(RecvFlags::DONTWAIT);
self.general.recv_poller_with(self, extra_nb, || {
let mut conn_guard = conn.lock();
if conn_guard.rx_closed() && conn_guard.rx_buffer_used() == 0 {
return Ok(0); }
if !matches!(
conn_guard.state(),
ConnectionState::Connected | ConnectionState::Closed
) {
return Err(AxError::NotConnected);
}
if conn_guard.rx_buffer_used() == 0 {
return Err(AxError::WouldBlock);
}
let (left, right) = conn_guard.rx_slices();
let mut count = dst.write(left)?;
if count >= left.len() && !right.is_empty() {
count += dst.write(right)?;
}
if !options.flags.contains(RecvFlags::PEEK) {
conn_guard.advance_rx_read(count);
}
if count > 0 {
trace!(
"Recv {} bytes from connection (buffer_remaining={}/{})",
count,
conn_guard.rx_buffer_used(),
VSOCK_RX_BUFFER_SIZE
);
Ok(count)
} else {
Err(AxError::WouldBlock)
}
})
}
fn shutdown(&self, how: Shutdown) -> AxResult<()> {
let conn = self.get_connection()?;
let mut conn = conn.lock();
if how.has_read() {
conn.set_rx_closed(true);
}
if how.has_write() {
conn.set_tx_closed(true);
}
if let Some(conn_id) = *self.conn_id.lock() {
if conn.state() == ConnectionState::Connected {
vsock_disconnect(conn_id)?;
} else if conn.state() == ConnectionState::Listening {
VSOCK_CONN_MANAGER.lock().unlisten(conn_id.local_port);
}
}
conn.set_state(ConnectionState::Closed);
Ok(())
}
fn local_addr(&self) -> AxResult<Option<VsockAddr>> {
Ok(self
.get_connection()
.ok()
.map(|conn| conn.lock().local_addr()))
}
fn peer_addr(&self) -> AxResult<Option<VsockAddr>> {
Ok(self
.get_connection()
.ok()
.and_then(|conn| conn.lock().peer_addr()))
}
}
impl Pollable for VsockStreamTransport {
fn poll(&self) -> IoEvents {
let Ok(conn) = self.get_connection() else {
return IoEvents::empty();
};
let conn = conn.lock();
let mut events = IoEvents::empty();
match conn.state() {
ConnectionState::Listening => {
if let Some(conn_id) = *self.conn_id.lock() {
events.set(
IoEvents::IN,
VSOCK_CONN_MANAGER.lock().can_accept(conn_id.local_port),
);
}
}
ConnectionState::Connected | ConnectionState::Closed => {
events.set(IoEvents::IN, conn.rx_buffer_used() > 0 || conn.rx_closed());
events.set(IoEvents::OUT, !conn.tx_closed());
}
ConnectionState::Connecting => {
events.set(IoEvents::OUT, conn.state() == ConnectionState::Connected);
}
_ => {}
}
events.set(IoEvents::RDHUP, conn.rx_closed());
events
}
fn register(&self, context: &mut Context<'_>, events: IoEvents) {
if let Ok(conn) = self.get_connection() {
let mut conn = conn.lock();
match conn.state() {
ConnectionState::Listening if events.contains(IoEvents::IN) => {
conn.register_accept_poll(context);
}
ConnectionState::Connected => {
if events.contains(IoEvents::IN) {
conn.register_rx_poll(context);
}
if events.contains(IoEvents::OUT) {
warn!(
"VsockStreamTransport: OUT event on connected socket is not supported"
);
}
}
ConnectionState::Connecting if events.contains(IoEvents::OUT) => {
conn.register_connect_poll(context);
}
_ => {}
}
}
}
}
impl Drop for VsockStreamTransport {
fn drop(&mut self) {
let _ = self.shutdown(Shutdown::Both);
if let Some(conn_id) = *self.conn_id.lock() {
VSOCK_CONN_MANAGER.lock().remove_connection(conn_id);
}
}
}