use crate::delegate::{
IncomingConnection, ListenerHandle, create_delegate_instance, register_listener,
unregister_listener,
};
use crate::error::{VZError, VZResult};
use crate::ffi::block::{
_Block_release, VsockResult, create_blocking_vsock_context_block, create_vsock_context_block,
};
use crate::ffi::get_class;
use crate::msg_send;
use objc2::runtime::AnyObject;
use std::ffi::c_void;
use std::os::unix::io::RawFd;
use std::sync::mpsc as std_mpsc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
// libdispatch entry point used by the connect methods to run `connect_work`
// asynchronously on the VM queue. The queue is passed as an untyped
// Objective-C object pointer.
unsafe extern "C" {
    /// Submits `work(context)` to `queue` and returns without waiting for it.
    fn dispatch_async_f(
        queue: *mut AnyObject,
        context: *mut c_void,
        work: unsafe extern "C" fn(*mut c_void),
    );
}
/// Heap-allocated argument bundle for `connect_work`.
///
/// The connect methods create it with `Box::into_raw`, hand it to
/// `dispatch_async_f`, and `connect_work` reclaims it with `Box::from_raw`.
struct ConnectContext {
    /// Objective-C device object (`VirtioSocketDevice::inner`) receiving the message.
    device: *mut AnyObject,
    /// Port passed to `connectToPort:completionHandler:`.
    port: u32,
    /// Completion-handler block passed to the same call; the connect methods
    /// keep their own pointer to it and call `_Block_release` on it themselves.
    block: *const c_void,
}
// SAFETY: the pointers are never dereferenced in Rust; `connect_work` only
// forwards them to `objc_msgSend` on the dispatch queue.
unsafe impl Send for ConnectContext {}
unsafe extern "C" fn connect_work(ctx: *mut c_void) {
unsafe {
let context = Box::from_raw(ctx as *mut ConnectContext);
tracing::debug!(
"connect_work: calling connectToPort:{} on device {:?}",
context.port,
context.device
);
let sel = objc2::sel!(connectToPort:completionHandler:);
let func: unsafe extern "C" fn(*mut AnyObject, objc2::runtime::Sel, u32, *const c_void) =
std::mem::transmute(crate::ffi::runtime::objc_msgSend as *const c_void);
func(context.device, sel, context.port, context.block);
}
}
/// Handle to a virtio socket device object of a virtual machine.
///
/// Both fields are raw pointers stored as given to `from_raw`; no retain or
/// release is performed by the code visible here (NOTE(review): confirm who
/// owns their lifetimes).
pub struct VirtioSocketDevice {
    /// The Objective-C device object that socket messages are sent to.
    inner: *mut AnyObject,
    /// Dispatch queue onto which `connect`, `connect_blocking` and `listen`
    /// submit their device calls (logged as the "VM queue").
    queue: *mut AnyObject,
}
// NOTE(review): these impls assert the wrapped pointers may be shared across
// threads; soundness relies on device calls being funnelled through `queue`.
unsafe impl Send for VirtioSocketDevice {}
unsafe impl Sync for VirtioSocketDevice {}
impl VirtioSocketDevice {
    /// Wraps a raw device object and the dispatch queue its calls run on.
    ///
    /// Nothing is retained here; both pointers are stored as-is.
    pub(crate) fn from_raw(ptr: *mut AnyObject, queue: *mut AnyObject) -> Self {
        Self { inner: ptr, queue }
    }

    /// Connects to `port` on the device without blocking the caller.
    ///
    /// `connectToPort:completionHandler:` is dispatched asynchronously onto the
    /// VM queue. Its completion block reports back through a oneshot channel,
    /// which is awaited for at most 10 seconds.
    ///
    /// # Errors
    ///
    /// * [`VZError::ConnectionFailed`] — the framework reported an error.
    /// * [`VZError::Internal`] — the completion channel closed without a result.
    /// * [`VZError::Timeout`] — no result arrived within 10 seconds.
    pub async fn connect(&self, port: u32) -> VZResult<VirtioSocketConnection> {
        tracing::debug!("VirtioSocketDevice::connect(port={})", port);
        let (tx, rx) = oneshot::channel::<VsockResult>();
        let block = create_vsock_context_block(tx);
        tracing::debug!("Created vsock context block: {:?}", block);
        let context = Box::new(ConnectContext {
            device: self.inner,
            port,
            block,
        });
        let context_ptr = Box::into_raw(context);
        // SAFETY: ownership of `context_ptr` passes to `connect_work`, which
        // reclaims it with `Box::from_raw` when the queue runs it.
        unsafe {
            tracing::debug!("Dispatching connect to VM queue {:?}", self.queue);
            dispatch_async_f(self.queue, context_ptr as *mut c_void, connect_work);
        }
        let timeout = Duration::from_secs(10);
        match tokio::time::timeout(timeout, rx).await {
            Ok(Ok(result)) => {
                // SAFETY: the completion handler has already run, so this
                // side no longer needs its reference to the block.
                unsafe {
                    _Block_release(block);
                }
                match result {
                    Ok(info) => {
                        tracing::info!(
                            "Vsock connected: fd={}, src_port={}, dst_port={}",
                            info.fd,
                            info.source_port,
                            info.destination_port
                        );
                        Ok(VirtioSocketConnection {
                            fd: info.fd,
                            source_port: info.source_port,
                            destination_port: info.destination_port,
                        })
                    }
                    Err(e) => {
                        // Resets/refusals are expected while the peer is not
                        // listening yet, so keep them out of warning logs.
                        if is_transient_connect_error(&e.message) {
                            tracing::debug!(
                                port,
                                error = %e.message,
                                "Vsock connection not ready yet"
                            );
                        } else {
                            tracing::warn!(
                                port,
                                error = %e.message,
                                "Vsock connection failed"
                            );
                        }
                        Err(VZError::ConnectionFailed(e.message))
                    }
                }
            }
            Ok(Err(_)) => {
                // SAFETY: the sender was dropped, so the block will not
                // deliver a result through this channel anymore.
                unsafe {
                    _Block_release(block);
                }
                Err(VZError::Internal {
                    code: -1,
                    message: "Connection channel closed unexpectedly".into(),
                })
            }
            Err(_) => {
                // `block` is intentionally NOT released on timeout. The work
                // item may not have run yet (its context still points at the
                // block), and the completion handler may still fire. Freeing
                // the block here could therefore cause a use-after-free, so
                // leaking it is the safer trade-off. The same leak happens if
                // this future is dropped before completion.
                tracing::warn!("Vsock connection timed out after {:?}", timeout);
                Err(VZError::Timeout(format!(
                    "Vsock connection to port {port} timed out"
                )))
            }
        }
    }

    /// Connects to `port` on the device, blocking the calling thread for at
    /// most `timeout`.
    ///
    /// The flow is the same as [`connect`](Self::connect), except that the
    /// completion block reports through a `std::sync::mpsc` channel.
    ///
    /// # Errors
    ///
    /// * [`VZError::ConnectionFailed`] — the framework reported an error.
    /// * [`VZError::Timeout`] — no result arrived within `timeout`.
    /// * [`VZError::Internal`] — the completion channel disconnected.
    pub fn connect_blocking(
        &self,
        port: u32,
        timeout: Duration,
    ) -> VZResult<VirtioSocketConnection> {
        tracing::debug!("VirtioSocketDevice::connect_blocking(port={})", port);
        let (tx, rx) = std_mpsc::channel::<VsockResult>();
        let block = create_blocking_vsock_context_block(tx);
        let context = Box::new(ConnectContext {
            device: self.inner,
            port,
            block,
        });
        let context_ptr = Box::into_raw(context);
        // SAFETY: ownership of `context_ptr` passes to `connect_work`.
        unsafe {
            tracing::debug!("Dispatching blocking connect to VM queue {:?}", self.queue);
            dispatch_async_f(self.queue, context_ptr as *mut c_void, connect_work);
        }
        match rx.recv_timeout(timeout) {
            Ok(Ok(info)) => {
                // SAFETY: the completion handler has already delivered its result.
                unsafe {
                    _Block_release(block);
                }
                tracing::info!(
                    "Vsock connected: fd={}, src_port={}, dst_port={}",
                    info.fd,
                    info.source_port,
                    info.destination_port
                );
                Ok(VirtioSocketConnection {
                    fd: info.fd,
                    source_port: info.source_port,
                    destination_port: info.destination_port,
                })
            }
            Ok(Err(e)) => {
                // SAFETY: the completion handler has already delivered its result.
                unsafe {
                    _Block_release(block);
                }
                if is_transient_connect_error(&e.message) {
                    tracing::debug!(
                        port,
                        error = %e.message,
                        "Vsock connection not ready yet"
                    );
                } else {
                    tracing::warn!(
                        port,
                        error = %e.message,
                        "Vsock connection failed"
                    );
                }
                Err(VZError::ConnectionFailed(e.message))
            }
            Err(std_mpsc::RecvTimeoutError::Timeout) => {
                // Deliberately leak `block`: the pending work item or a late
                // completion may still use it (see `connect`).
                tracing::warn!("Vsock connection timed out after {:?}", timeout);
                Err(VZError::Timeout(format!(
                    "Vsock connection to port {port} timed out"
                )))
            }
            Err(std_mpsc::RecvTimeoutError::Disconnected) => {
                // SAFETY: the sender is gone, so no result will arrive.
                unsafe {
                    _Block_release(block);
                }
                Err(VZError::Internal {
                    code: -1,
                    message: "Connection channel closed unexpectedly".into(),
                })
            }
        }
    }

    /// Registers a listener for incoming connections on `port`.
    ///
    /// This creates a `VZVirtioSocketListener` and attaches a delegate bound
    /// to a new listener-registry entry. The listener is then installed with
    /// `setSocketListener:forPort:`, synchronously on the VM queue. For that
    /// reason this must not be called from that queue, because
    /// `dispatch_sync_f` onto the current queue deadlocks.
    ///
    /// # Errors
    ///
    /// Returns [`VZError::Internal`] if the listener class is missing, the
    /// listener cannot be allocated, or the delegate cannot be created.
    pub fn listen(&self, port: u32) -> VZResult<VirtioSocketListener> {
        tracing::debug!("VirtioSocketDevice::listen(port={})", port);
        unsafe {
            let listener_cls =
                get_class("VZVirtioSocketListener").ok_or_else(|| VZError::Internal {
                    code: -1,
                    message: "VZVirtioSocketListener class not found".into(),
                })?;
            // `new` returns a +1 reference that we own from here on.
            let listener_obj: *mut AnyObject = msg_send!(listener_cls, new);
            if listener_obj.is_null() {
                return Err(VZError::Internal {
                    code: -1,
                    message: "Failed to create VZVirtioSocketListener".into(),
                });
            }
            let (tx, rx) = mpsc::unbounded_channel::<IncomingConnection>();
            let handle = register_listener(tx);
            let delegate = match create_delegate_instance(handle) {
                Ok(d) => d,
                Err(e) => {
                    // Undo everything acquired so far: the registry entry
                    // and our reference to the listener object.
                    unregister_listener(handle);
                    crate::ffi::release(listener_obj);
                    return Err(VZError::Internal {
                        code: -1,
                        message: format!("Failed to create delegate instance: {e}"),
                    });
                }
            };
            tracing::debug!(
                "Setting delegate {:?} on listener {:?}",
                delegate,
                listener_obj
            );
            let set_delegate_sel = objc2::sel!(setDelegate:);
            let set_delegate_fn: unsafe extern "C" fn(
                *mut AnyObject,
                objc2::runtime::Sel,
                *mut AnyObject,
            ) = std::mem::transmute(crate::ffi::runtime::objc_msgSend as *const c_void);
            set_delegate_fn(listener_obj, set_delegate_sel, delegate);
            tracing::debug!("Delegate set successfully");
            tracing::debug!(
                "Calling setSocketListener:forPort: on device {:?} via dispatch queue {:?}",
                self.inner,
                self.queue
            );
            /// Arguments for `set_listener_work`, moved through `dispatch_sync_f`.
            struct SetListenerContext {
                device: *mut AnyObject,
                listener: *mut AnyObject,
                port: u32,
            }
            unsafe impl Send for SetListenerContext {}
            /// Runs on the VM queue: reclaims the boxed context and sends
            /// `setSocketListener:forPort:` to the device.
            unsafe extern "C" fn set_listener_work(ctx: *mut c_void) {
                unsafe {
                    let context = Box::from_raw(ctx as *mut SetListenerContext);
                    tracing::debug!(
                        "set_listener_work: device={:?}, listener={:?}, port={}",
                        context.device,
                        context.listener,
                        context.port
                    );
                    let set_listener_sel = objc2::sel!(setSocketListener:forPort:);
                    let set_listener_fn: unsafe extern "C" fn(
                        *mut AnyObject,
                        objc2::runtime::Sel,
                        *mut AnyObject,
                        u32,
                    ) = std::mem::transmute(crate::ffi::runtime::objc_msgSend as *const c_void);
                    set_listener_fn(
                        context.device,
                        set_listener_sel,
                        context.listener,
                        context.port,
                    );
                    tracing::debug!("set_listener_work completed");
                }
            }
            let context = Box::new(SetListenerContext {
                device: self.inner,
                listener: listener_obj,
                port,
            });
            let context_ptr = Box::into_raw(context);
            unsafe extern "C" {
                fn dispatch_sync_f(
                    queue: *mut AnyObject,
                    context: *mut c_void,
                    work: unsafe extern "C" fn(*mut c_void),
                );
            }
            dispatch_sync_f(self.queue, context_ptr as *mut c_void, set_listener_work);
            tracing::debug!("setSocketListener completed");
            tracing::info!("Listening on port {} with handle {}", port, handle);
            Ok(VirtioSocketListener {
                port,
                handle,
                receiver: rx,
                listener_obj,
                delegate,
            })
        }
    }

    /// Removes whatever socket listener is installed for `port` on the device.
    ///
    /// This sends `removeSocketListenerForPort:` rather than passing `nil` to
    /// `setSocketListener:forPort:`, whose listener argument is non-null. Like
    /// [`listen`](Self::listen), the call runs synchronously on the VM queue,
    /// so this must not be called from that queue.
    pub fn remove_listener(&self, port: u32) {
        tracing::debug!("VirtioSocketDevice::remove_listener(port={})", port);

        /// Arguments for `remove_listener_work`, moved through `dispatch_sync_f`.
        struct RemoveListenerContext {
            device: *mut AnyObject,
            port: u32,
        }

        /// Runs on the VM queue: reclaims the boxed context and sends
        /// `removeSocketListenerForPort:` to the device.
        unsafe extern "C" fn remove_listener_work(ctx: *mut c_void) {
            unsafe {
                let context = Box::from_raw(ctx as *mut RemoveListenerContext);
                let remove_sel = objc2::sel!(removeSocketListenerForPort:);
                let remove_fn: unsafe extern "C" fn(*mut AnyObject, objc2::runtime::Sel, u32) =
                    std::mem::transmute(crate::ffi::runtime::objc_msgSend as *const c_void);
                remove_fn(context.device, remove_sel, context.port);
            }
        }

        unsafe extern "C" {
            fn dispatch_sync_f(
                queue: *mut AnyObject,
                context: *mut c_void,
                work: unsafe extern "C" fn(*mut c_void),
            );
        }

        let context_ptr = Box::into_raw(Box::new(RemoveListenerContext {
            device: self.inner,
            port,
        }));
        // SAFETY: `dispatch_sync_f` runs `remove_listener_work` exactly once
        // before returning, and that function takes back ownership of
        // `context_ptr`.
        unsafe {
            dispatch_sync_f(self.queue, context_ptr as *mut c_void, remove_listener_work);
        }
    }
}
/// Reports whether a connect failure message looks transient.
///
/// Returns `true` when `message` mentions, ASCII case-insensitively, a
/// connection reset, refusal or abort, or a broken pipe. Callers log such
/// failures at debug level as "not ready yet" instead of warning.
fn is_transient_connect_error(message: &str) -> bool {
    const TRANSIENT_MARKERS: [&str; 4] = [
        "connection reset",
        "connection refused",
        "connection aborted",
        "broken pipe",
    ];
    let lowered = message.to_ascii_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|marker| lowered.contains(marker))
}
/// An established vsock connection that owns its file descriptor.
///
/// The descriptor is closed on drop unless it is released with `into_raw_fd`.
pub struct VirtioSocketConnection {
    /// Socket descriptor; closed in `Drop` when non-negative.
    fd: RawFd,
    /// Source port as reported when the connection was established.
    source_port: u32,
    /// Destination port as reported when the connection was established.
    destination_port: u32,
}
impl VirtioSocketConnection {
#[inline]
#[must_use]
pub fn as_raw_fd(&self) -> RawFd {
self.fd
}
#[inline]
#[must_use]
pub fn source_port(&self) -> u32 {
self.source_port
}
#[inline]
#[must_use]
pub fn destination_port(&self) -> u32 {
self.destination_port
}
pub fn read(&self, buf: &mut [u8]) -> std::io::Result<usize> {
let n = unsafe { libc::read(self.fd, buf.as_mut_ptr() as *mut c_void, buf.len()) };
if n < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(n as usize)
}
}
pub fn write(&self, buf: &[u8]) -> std::io::Result<usize> {
let n = unsafe { libc::write(self.fd, buf.as_ptr() as *const c_void, buf.len()) };
if n < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(n as usize)
}
}
#[must_use]
pub fn into_raw_fd(self) -> RawFd {
let fd = self.fd;
std::mem::forget(self);
fd
}
}
impl Drop for VirtioSocketConnection {
    /// Closes the descriptor; negative values are treated as "no descriptor".
    fn drop(&mut self) {
        if self.fd < 0 {
            return;
        }
        // SAFETY: this value owns `fd`. `into_raw_fd` is the only visible way
        // to relinquish it, and that path skips this destructor.
        unsafe {
            libc::close(self.fd);
        }
    }
}
/// A socket listener created by `VirtioSocketDevice::listen` for one port.
///
/// Incoming connections are read from `receiver`. On drop, the registry entry
/// is unregistered and both Objective-C objects are released.
pub struct VirtioSocketListener {
    /// Port passed to `listen`.
    port: u32,
    /// Registry handle returned by `register_listener`.
    handle: ListenerHandle,
    /// Receiving end of the channel whose sender was registered with
    /// `register_listener` (presumably fed by the delegate — confirm in
    /// `crate::delegate`).
    receiver: mpsc::UnboundedReceiver<IncomingConnection>,
    /// `VZVirtioSocketListener` instance obtained from `new`; released on drop.
    listener_obj: *mut AnyObject,
    /// Delegate installed on `listener_obj` via `setDelegate:`; released on drop.
    delegate: *mut AnyObject,
}
// NOTE(review): asserts the raw Objective-C pointers may move to another
// thread; they are only touched again in `Drop`.
unsafe impl Send for VirtioSocketListener {}
impl VirtioSocketListener {
#[inline]
#[must_use]
pub fn port(&self) -> u32 {
self.port
}
pub async fn accept(&mut self) -> VZResult<VirtioSocketConnection> {
match self.receiver.recv().await {
Some(incoming) => {
tracing::debug!(
"Accepted connection: fd={}, src={}, dst={}",
incoming.fd,
incoming.source_port,
incoming.destination_port
);
Ok(VirtioSocketConnection {
fd: incoming.fd,
source_port: incoming.source_port,
destination_port: incoming.destination_port,
})
}
None => {
Err(VZError::OperationFailed("Listener closed".into()))
}
}
}
pub fn try_accept(&mut self) -> Option<VirtioSocketConnection> {
match self.receiver.try_recv() {
Ok(incoming) => Some(VirtioSocketConnection {
fd: incoming.fd,
source_port: incoming.source_port,
destination_port: incoming.destination_port,
}),
Err(_) => None,
}
}
}
impl Drop for VirtioSocketListener {
    /// Unregisters the listener handle, then releases the listener object and
    /// the delegate (in that order), skipping null pointers.
    fn drop(&mut self) {
        let port = self.port;
        tracing::debug!("Dropping VirtioSocketListener for port {}", port);
        // Remove the registry entry created in `listen` before the objects go.
        unregister_listener(self.handle);
        for obj in [self.listener_obj, self.delegate] {
            if !obj.is_null() {
                crate::ffi::release(obj);
            }
        }
        tracing::debug!("VirtioSocketListener dropped for port {}", port);
    }
}