#![cfg(all(target_os = "macos", target_arch = "aarch64"))]
use std::fmt;
use std::io::{Read, Write};
use std::net::{Shutdown, TcpListener, TcpStream};
use std::os::fd::{AsRawFd, FromRawFd};
use std::os::unix::net::UnixListener;
use std::os::unix::net::UnixStream;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::devices::virtio::vsock::device::Vsock;
use crate::devices::virtio::vsock::mux_profile::{self, Stage};
/// Failure to bring up one of the host-side vsock frontends.
///
/// `frontend` names the listener that failed (e.g. `"vsock-mux"`, `"http-port"`)
/// and `source` carries the underlying I/O error. The cause's message is embedded
/// in the `Display` output; it is not additionally exposed through
/// `Error::source`, so reporters do not print it twice.
#[derive(Debug)]
pub enum StartError {
    /// Binding the listening socket at `endpoint` failed.
    Bind {
        frontend: &'static str,
        endpoint: String,
        source: std::io::Error,
    },
    /// Querying the bound address of the listener at `endpoint` failed.
    LocalAddr {
        frontend: &'static str,
        endpoint: String,
        source: std::io::Error,
    },
    /// Spawning the background thread called `name` failed.
    ThreadSpawn {
        name: String,
        source: std::io::Error,
    },
}

impl fmt::Display for StartError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Bind and LocalAddr share the "<frontend>: <op> <endpoint>: <cause>" layout.
        let (frontend, op, endpoint, source) = match self {
            Self::Bind {
                frontend,
                endpoint,
                source,
            } => (frontend, "bind", endpoint, source),
            Self::LocalAddr {
                frontend,
                endpoint,
                source,
            } => (frontend, "local_addr", endpoint, source),
            Self::ThreadSpawn { name, source } => {
                return write!(f, "spawn thread {name}: {source}");
            }
        };
        write!(f, "{frontend}: {op} {endpoint}: {source}")
    }
}

impl std::error::Error for StartError {}
/// Best-effort restriction of the socket file at `sock_path` to owner
/// read/write (mode 0600). A failure is only reported on stderr.
///
/// NOTE(review): this runs after `bind`, so between bind and chmod the socket
/// presumably carries umask-derived permissions. Confirm whether that window matters.
fn lock_socket_perms(sock_path: &str) {
    use std::os::unix::fs::PermissionsExt;
    let owner_only = std::fs::Permissions::from_mode(0o600);
    match std::fs::set_permissions(sock_path, owner_only) {
        Ok(()) => {}
        Err(e) => eprintln!(" [vsock] warn: chmod 0600 {sock_path}: {e}"),
    }
}
pub fn start(sock_path: &str, vsock: Arc<Vsock>, vm_port: Option<u32>) -> Result<(), StartError> {
let _ = std::fs::remove_file(sock_path);
let listener = UnixListener::bind(sock_path).map_err(|source| StartError::Bind {
frontend: "vsock-mux",
endpoint: sock_path.to_string(),
source,
})?;
lock_socket_perms(sock_path);
eprintln!(" vsock-mux on {sock_path} -> guest vm_port={:?}", vm_port);
let name = "vsock-mux-acceptor".to_string();
std::thread::Builder::new()
.name(name.clone())
.spawn(move || {
for stream in listener.incoming() {
let stream = match stream {
Ok(s) => s,
Err(e) => {
eprintln!("[vsock-mux] accept err: {e}");
continue;
}
};
let vsock_c = vsock.clone();
std::thread::Builder::new()
.name("vsock-mux-conn".to_string())
.spawn(move || handle_conn(stream, vsock_c.as_ref(), vm_port))
.ok();
}
})
.map_err(|source| StartError::ThreadSpawn { name, source })?;
Ok(())
}
pub fn start_exec(
sock_path: &str,
vsock: Arc<Vsock>,
guest_port: u32,
) -> Result<(), StartError> {
let _ = std::fs::remove_file(sock_path);
let listener = UnixListener::bind(sock_path).map_err(|source| StartError::Bind {
frontend: "vsock-exec",
endpoint: sock_path.to_string(),
source,
})?;
lock_socket_perms(sock_path);
eprintln!(
" vsock-exec on {sock_path} -> guest port {guest_port} (native AF_VSOCK)"
);
let name = "vsock-exec-acceptor".to_string();
std::thread::Builder::new()
.name(name.clone())
.spawn(move || {
for stream in listener.incoming() {
let stream = match stream {
Ok(s) => s,
Err(e) => {
eprintln!("[vsock-exec] accept err: {e}");
continue;
}
};
let vsock_c = vsock.clone();
std::thread::Builder::new()
.name("vsock-exec-conn".to_string())
.spawn(move || {
if let Err(e) = vsock_c.muxer().open_native_to_guest(
crate::devices::virtio::vsock::muxer_thread::MuxerStream::Unix(stream),
guest_port,
) {
eprintln!("[vsock-exec] open_native_to_guest: {e}");
}
})
.ok();
}
})
.map_err(|source| StartError::ThreadSpawn { name, source })?;
Ok(())
}
pub fn start_handoff(
sock_path: &str,
vsock: Arc<Vsock>,
vm_port: Option<u32>,
) -> Result<(), StartError> {
let _ = std::fs::remove_file(sock_path);
let listener = UnixListener::bind(sock_path).map_err(|source| StartError::Bind {
frontend: "vsock-mux-handoff",
endpoint: sock_path.to_string(),
source,
})?;
lock_socket_perms(sock_path);
eprintln!(
" vsock-mux-handoff on {sock_path} -> guest vm_port={:?}",
vm_port
);
let name = "vsock-mux-handoff-acceptor".to_string();
std::thread::Builder::new()
.name(name.clone())
.spawn(move || {
for stream in listener.incoming() {
let stream = match stream {
Ok(s) => s,
Err(e) => {
eprintln!("[vsock-mux-handoff] accept err: {e}");
continue;
}
};
let vsock_c = vsock.clone();
std::thread::Builder::new()
.name("vsock-mux-handoff-conn".to_string())
.spawn(move || handle_handoff_conn(stream, vsock_c, vm_port))
.ok();
}
})
.map_err(|source| StartError::ThreadSpawn { name, source })?;
Ok(())
}
/// Serves one handoff control connection.
///
/// It repeatedly receives a TCP socket plus its already-read prefix bytes via
/// [`recv_handoff`], then forwards each pair to the guest with
/// `open_tcp_to_guest_with_prefix`. It returns when the control connection
/// reaches EOF (silently) or reports any other receive error (logged).
fn handle_handoff_conn(mut conn: UnixStream, vsock: Arc<Vsock>, vm_port: Option<u32>) {
    // Give the guest listener a chance to register before the first handoff.
    // The outcome is deliberately ignored; a failed open is logged below.
    let _ = wait_for_host_port(&vsock, vm_port);
    loop {
        let (tcp, prefix) = match recv_handoff(&mut conn) {
            Ok(pair) => pair,
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return,
            Err(e) => {
                eprintln!("[vsock-mux-handoff] recv: {e}");
                return;
            }
        };
        if let Err(e) = vsock
            .muxer()
            .open_tcp_to_guest_with_prefix(tcp, prefix, vm_port)
        {
            eprintln!("[vsock-mux-handoff] open_tcp_to_guest_with_prefix: {e}");
        }
    }
}
fn recv_handoff(conn: &mut UnixStream) -> std::io::Result<(TcpStream, Vec<u8>)> {
let mut len_buf = [0u8; 4];
let mut got = 0usize;
while got < 4 {
let n = conn.read(&mut len_buf[got..])?;
if n == 0 {
return Err(std::io::ErrorKind::UnexpectedEof.into());
}
got += n;
}
let prefix_len = u32::from_be_bytes(len_buf) as usize;
if prefix_len > 1 << 20 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("handoff prefix too large: {prefix_len}"),
));
}
if prefix_len == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"handoff prefix_len=0 — sender must include at least 1 byte with the cmsg",
));
}
let mut prefix = vec![0u8; prefix_len];
let mut filled = 0usize;
let mut fd: Option<libc::c_int> = None;
while filled < prefix_len {
let mut iov = libc::iovec {
iov_base: prefix[filled..].as_mut_ptr() as *mut libc::c_void,
iov_len: prefix_len - filled,
};
let cmsg_len = unsafe { libc::CMSG_SPACE(std::mem::size_of::<libc::c_int>() as u32) };
let mut cmsg_buf = vec![0u8; cmsg_len as usize];
let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
msg.msg_iov = &mut iov as *mut libc::iovec;
msg.msg_iovlen = 1;
msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
msg.msg_controllen = cmsg_len as _;
let n = unsafe { libc::recvmsg(conn.as_raw_fd(), &mut msg, 0) };
if n < 0 {
let err = std::io::Error::last_os_error();
if err.kind() == std::io::ErrorKind::Interrupted {
continue;
}
return Err(err);
}
if n == 0 {
return Err(std::io::ErrorKind::UnexpectedEof.into());
}
if fd.is_none() {
let mut cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(&msg) };
while !cmsg_ptr.is_null() {
let cmsg = unsafe { &*cmsg_ptr };
if cmsg.cmsg_level == libc::SOL_SOCKET && cmsg.cmsg_type == libc::SCM_RIGHTS {
let data_ptr = unsafe { libc::CMSG_DATA(cmsg_ptr) } as *const libc::c_int;
let one = unsafe { std::ptr::read_unaligned(data_ptr) };
fd = Some(one);
break;
}
cmsg_ptr = unsafe { libc::CMSG_NXTHDR(&msg, cmsg_ptr) };
}
}
filled += n as usize;
}
let Some(fd) = fd else {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"handoff: SCM_RIGHTS cmsg missing",
));
};
let tcp = unsafe { TcpStream::from_raw_fd(fd) };
let _ = tcp.set_nodelay(true);
Ok((tcp, prefix))
}
pub fn start_tcp(addr: &str, vsock: Arc<Vsock>, vm_port: Option<u32>) -> Result<(), StartError> {
let listener = TcpListener::bind(addr).map_err(|source| StartError::Bind {
frontend: "http-port",
endpoint: addr.to_string(),
source,
})?;
let local = listener
.local_addr()
.map_err(|source| StartError::LocalAddr {
frontend: "http-port",
endpoint: addr.to_string(),
source,
})?;
eprintln!(" http-port on {local} -> guest vm_port={vm_port:?}");
let name = "http-port-acceptor".to_string();
std::thread::Builder::new()
.name(name.clone())
.spawn(move || {
for stream in listener.incoming() {
let stream = match stream {
Ok(s) => s,
Err(e) => {
eprintln!("[http-port] accept err: {e}");
continue;
}
};
let _ = stream.set_nodelay(true);
let vsock_c = vsock.clone();
std::thread::Builder::new()
.name("http-port-conn".to_string())
.spawn(move || handle_tcp_conn(stream, vsock_c, vm_port))
.ok();
}
})
.map_err(|source| StartError::ThreadSpawn { name, source })?;
Ok(())
}
/// Forwards one accepted `http-port` TCP connection to the guest.
///
/// It first waits (bounded, via `wait_for_host_port`) for the guest listener to
/// be registered. If none appears, the connection is dropped and the reason is
/// logged, as the Unix `vsock-mux` path does, rather than disappearing silently.
fn handle_tcp_conn(client: TcpStream, vsock: Arc<Vsock>, vm_port: Option<u32>) {
    if wait_for_host_port(&vsock, vm_port).is_none() {
        eprintln!("[http-port] no host port (vm_port={vm_port:?})");
        return;
    }
    if let Err(e) = vsock.muxer().open_tcp_to_guest(client, vm_port) {
        eprintln!("[http-port] direct tcp->guest failed after listener ready: {e}");
    }
}
/// Forwards one accepted `vsock-mux` Unix connection to the guest.
///
/// If `wait_for_host_port` finds no host port for `vm_port`, the connection
/// is logged and dropped. Otherwise it is handed to `open_unix_to_guest`, and
/// any failure there is logged.
fn handle_conn(client: UnixStream, vsock: &Vsock, vm_port: Option<u32>) {
    let ready = wait_for_host_port(vsock, vm_port).is_some();
    if !ready {
        eprintln!("[vsock-mux] no host port (vm_port={vm_port:?})");
        return;
    }
    if let Err(e) = vsock.muxer().open_unix_to_guest(client, vm_port) {
        eprintln!("[vsock-mux] direct unix->guest failed after listener ready: {e}");
    }
}
/// Connects to `port` on the IPv4 loopback address, falling back to IPv6
/// loopback (`::1`) if that fails. On a double failure, the IPv6 error is
/// returned.
fn connect_loopback(port: u16) -> std::io::Result<TcpStream> {
    match TcpStream::connect(("127.0.0.1", port)) {
        Ok(stream) => Ok(stream),
        Err(_) => TcpStream::connect(("::1", port)),
    }
}
/// Copies `r` into `w` until EOF or error, then half-closes `w` (best-effort)
/// so its peer sees EOF.
fn pump_tcp_to_tcp_shutdown(r: TcpStream, w: TcpStream) {
    // `&TcpStream` implements Read/Write, so no mutable bindings are needed.
    pump_bytes(&mut &r, &mut &w);
    let _ = w.shutdown(Shutdown::Write);
}
/// Copies the TCP stream `r` into the Unix stream `w` until EOF or error, then
/// half-closes `w` (best-effort) so its peer sees EOF.
fn pump_tcp_to_unix_shutdown(r: TcpStream, w: UnixStream) {
    // Both `&TcpStream` and `&UnixStream` implement Read/Write.
    pump_bytes(&mut &r, &mut &w);
    let _ = w.shutdown(Shutdown::Write);
}
/// Copies the Unix stream `r` into the TCP stream `w` until EOF or error, then
/// half-closes `w` (best-effort) so its peer sees EOF.
fn pump_unix_to_tcp_shutdown(r: UnixStream, w: TcpStream) {
    // Both `&UnixStream` and `&TcpStream` implement Read/Write.
    pump_bytes(&mut &r, &mut &w);
    let _ = w.shutdown(Shutdown::Write);
}
/// Pumps bytes in both directions between `unix` and `tcp` on the calling
/// thread, multiplexing the two sockets (switched to non-blocking) with
/// `poll(2)`.
///
/// The two directions close independently. When one side reads EOF, the
/// other side's write half is shut down and the EOF side is no longer polled,
/// but data still flowing the other way keeps being forwarded (half-close).
/// The function returns once both sides have reached EOF, or on the first
/// unrecoverable read, write or poll error. Per-chunk read and write timings
/// are recorded through `mux_profile`.
fn pump_unix_tcp_poll(mut unix: UnixStream, mut tcp: TcpStream) {
    let _ = unix.set_nonblocking(true);
    let _ = tcp.set_nonblocking(true);
    let ufd = unix.as_raw_fd();
    let tfd = tcp.as_raw_fd();
    let mut u2t = [0u8; 16 * 1024];
    let mut t2u = [0u8; 16 * 1024];
    let drain_mask = libc::POLLIN | libc::POLLERR | libc::POLLHUP;
    // Whether each side may still deliver data (it has not read EOF yet).
    let mut unix_open = true;
    let mut tcp_open = true;
    while unix_open || tcp_open {
        // poll(2) ignores entries with a negative fd and reports revents = 0
        // for them, so a side that already hit EOF stops waking us up.
        let mut pfds = [
            libc::pollfd {
                fd: if unix_open { ufd } else { -1 },
                events: libc::POLLIN,
                revents: 0,
            },
            libc::pollfd {
                fd: if tcp_open { tfd } else { -1 },
                events: libc::POLLIN,
                revents: 0,
            },
        ];
        // SAFETY: `pfds` is a live, fully initialised array of pollfd entries.
        let rc = unsafe { libc::poll(pfds.as_mut_ptr(), pfds.len() as _, -1) };
        if rc < 0 {
            let err = std::io::Error::last_os_error();
            if err.raw_os_error() == Some(libc::EINTR) {
                continue;
            }
            break;
        }
        // POLLNVAL is never cleared by reading, so it would make this loop spin.
        if (pfds[0].revents | pfds[1].revents) & libc::POLLNVAL != 0 {
            return;
        }
        if pfds[0].revents & drain_mask != 0 {
            loop {
                let t0 = Instant::now();
                match unix.read(&mut u2t) {
                    Ok(0) => {
                        mux_profile::record(
                            Stage::FrontendUnixRead,
                            0,
                            t0.elapsed().as_micros() as u64,
                        );
                        // Propagate the half-close but keep forwarding tcp -> unix.
                        let _ = tcp.shutdown(Shutdown::Write);
                        unix_open = false;
                        break;
                    }
                    Ok(n) => {
                        mux_profile::record(
                            Stage::FrontendUnixRead,
                            n,
                            t0.elapsed().as_micros() as u64,
                        );
                        let t1 = Instant::now();
                        if write_all_poll(tfd, &mut tcp, &u2t[..n]).is_err() {
                            return;
                        }
                        mux_profile::record(
                            Stage::FrontendUnixToTcpWrite,
                            n,
                            t1.elapsed().as_micros() as u64,
                        );
                    }
                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                    Err(_) => return,
                }
            }
        }
        if pfds[1].revents & drain_mask != 0 {
            loop {
                let t0 = Instant::now();
                match tcp.read(&mut t2u) {
                    Ok(0) => {
                        mux_profile::record(
                            Stage::FrontendTcpRead,
                            0,
                            t0.elapsed().as_micros() as u64,
                        );
                        // Propagate the half-close but keep forwarding unix -> tcp.
                        let _ = unix.shutdown(Shutdown::Write);
                        tcp_open = false;
                        break;
                    }
                    Ok(n) => {
                        mux_profile::record(
                            Stage::FrontendTcpRead,
                            n,
                            t0.elapsed().as_micros() as u64,
                        );
                        let t1 = Instant::now();
                        if write_all_poll(ufd, &mut unix, &t2u[..n]).is_err() {
                            return;
                        }
                        mux_profile::record(
                            Stage::FrontendTcpToUnixWrite,
                            n,
                            t1.elapsed().as_micros() as u64,
                        );
                    }
                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                    Err(_) => return,
                }
            }
        }
    }
}
/// Writes all of `buf` to the non-blocking writer `w`. Whenever the write
/// would block, it waits in `poll(2)` on `fd`, which the caller is expected to
/// pass as `w`'s own descriptor.
///
/// It mirrors `Write::write_all`: a zero-length write is reported as
/// `WriteZero`, and `Interrupted` writes are retried rather than treated as
/// fatal. Any other error is returned as-is.
fn write_all_poll<W: Write>(fd: libc::c_int, w: &mut W, mut buf: &[u8]) -> std::io::Result<()> {
    while !buf.is_empty() {
        match w.write(buf) {
            Ok(0) => return Err(std::io::ErrorKind::WriteZero.into()),
            Ok(n) => buf = &buf[n..],
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                wait_writable(fd)?;
            }
            Err(e) => return Err(e),
        }
    }
    Ok(())
}
/// Blocks in `poll(2)` (no timeout) until `fd` reports any event for a
/// `POLLOUT` request, retrying on `EINTR`.
///
/// `revents` is not inspected. Error or hang-up conditions also wake this
/// function, and the caller's next write surfaces them. Any other poll
/// failure is returned.
fn wait_writable(fd: libc::c_int) -> std::io::Result<()> {
    let mut pfd = libc::pollfd {
        fd,
        events: libc::POLLOUT,
        revents: 0,
    };
    // SAFETY: `pfd` is a single live pollfd for the duration of each call.
    while unsafe { libc::poll(&mut pfd, 1, -1) } < 0 {
        let err = std::io::Error::last_os_error();
        if err.raw_os_error() != Some(libc::EINTR) {
            return Err(err);
        }
    }
    Ok(())
}
/// By-value convenience wrapper around `pump_bytes`. It consumes both
/// endpoints and drops them once copying stops.
fn pump_bytes_owned<R: Read, W: Write>(r: R, w: W) {
    let (mut src, mut dst) = (r, w);
    pump_bytes(&mut src, &mut dst);
}
/// Copies bytes from `r` to `w` through a 16 KiB buffer until `r` reports EOF,
/// a non-retryable read error occurs, or a write fails.
///
/// Errors are not returned; they simply end the copy. Reads interrupted by a
/// signal (`ErrorKind::Interrupted`) are retried, just as `Write::write_all`
/// already does on the write side.
fn pump_bytes<R: Read, W: Write>(r: &mut R, w: &mut W) {
    let mut buf = [0u8; 16 * 1024];
    loop {
        let n = match r.read(&mut buf) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(_) => break,
        };
        if w.write_all(&buf[..n]).is_err() {
            break;
        }
    }
}
/// Looks up the host port the muxer maps to `vm_port`, or the muxer's first
/// host port when `vm_port` is `None`.
///
/// It retries up to 50 times with a 20 ms sleep after each miss (about 1 s in
/// total) and then makes one final lookup. Returns `None` if no port ever
/// appears.
fn wait_for_host_port(vsock: &Vsock, vm_port: Option<u32>) -> Option<u16> {
    const ATTEMPTS: u32 = 50;
    const RETRY_DELAY: Duration = Duration::from_millis(20);
    let lookup = || {
        if let Some(p) = vm_port {
            vsock.muxer().host_port_for_vm_port(p)
        } else {
            vsock.muxer().first_host_port()
        }
    };
    let found = (0..ATTEMPTS).find_map(|_| {
        let hit = lookup();
        if hit.is_none() {
            std::thread::sleep(RETRY_DELAY);
        }
        hit
    });
    found.or_else(lookup)
}