#![allow(dead_code)]
use std::collections::HashMap;
use std::net::{SocketAddr, TcpStream, UdpSocket};
use std::os::unix::net::UnixStream;
use std::sync::{mpsc, Arc, Mutex};
use std::time::Instant;
use super::mux_profile::{self, Stage};
use super::muxer_rxq::MuxerRxQ;
use super::muxer_thread::{self, MuxerCmd, MuxerStream};
use super::packet::{
Header, RxPacket, VSOCK_HOST_CID, VSOCK_OP_CREDIT_REQUEST, VSOCK_OP_CREDIT_UPDATE,
VSOCK_OP_REQUEST, VSOCK_OP_RESPONSE, VSOCK_OP_RST, VSOCK_OP_RW, VSOCK_OP_SHUTDOWN,
VSOCK_TYPE_DGRAM, VSOCK_TYPE_STREAM,
};
use super::proxy::{proxy_key, Proxy};
use super::tsi_stream::TsiListener;
/// Guest-visible vsock port of the TSI proxy data endpoint for outbound streams.
pub const TSI_PROXY_PORT: u32 = 620;
/// Stream port the guest connects to in order to fetch the env/secrets JSON.
/// NOTE(review): numerically equal to `TSI_GETNAME` (1026). They do not collide
/// in `handle_tx` because TSI control traffic is DGRAM-typed and checked first,
/// while this port only matches stream REQUESTs — confirm this is intentional.
pub const VSOCK_ENV_PORT: u32 = 1026;
// TSI control opcodes. Each is carried as the *destination port* of a DGRAM
// packet; `handle_tx` routes dst ports 1024..=1031 to `handle_tsi_control`.
pub const TSI_PROXY_CREATE: u32 = 1024;
pub const TSI_CONNECT: u32 = 1025;
pub const TSI_GETNAME: u32 = 1026;
pub const TSI_SENDTO_ADDR: u32 = 1027;
pub const TSI_SENDTO_DATA: u32 = 1028;
pub const TSI_LISTEN: u32 = 1029;
pub const TSI_ACCEPT: u32 = 1030;
pub const TSI_PROXY_RELEASE: u32 = 1031;
/// Serializable description of a bound TSI listener, captured by
/// `capture_tsi_listeners` and re-bound by `restore_tsi_listeners`
/// (e.g. across a snapshot/restore cycle).
#[derive(Clone, Debug)]
pub struct TsiListenerSnapshot {
    /// Guest context id that owns the listener.
    pub cid: u64,
    /// Guest-side control port the TSI state was created under.
    pub peer_port: u32,
    /// Guest port the listener accepts connections for.
    pub vm_port: u32,
    /// Address family as sent by the guest (e.g. `AF_INET`).
    pub family: u16,
    /// Socket type as sent by the guest (e.g. `SOCK_STREAM`).
    pub socktype: u16,
}
/// Per-(cid, peer_port) state created by TSI_PROXY_CREATE and later filled in
/// by TSI_LISTEN / TSI_CONNECT.
struct TsiState {
    /// Address family from PROXY_CREATE.
    family: u16,
    /// Socket type from PROXY_CREATE (distinguishes UDP from TCP on CONNECT).
    socktype: u16,
    /// Guest listen port; `Some` only after a LISTEN was processed.
    vm_port: Option<u32>,
    /// Host-side accept socket; `None` until LISTEN succeeds (and possibly
    /// still `None` if the host bind failed).
    listener: Option<TsiListener>,
}
/// A host->guest connection waiting for the guest's RESPONSE to our REQUEST.
/// The host stream is parked here until the handshake completes.
pub struct PendingInbound {
    /// Guest context id.
    pub cid: u64,
    /// Host-allocated source port identifying this connection.
    pub host_src_port: u32,
    /// Guest destination port (the listener's port).
    pub vm_port: u32,
    /// The host-side stream; taken (`Option::take`) when the guest accepts.
    pub stream: Option<MuxerStream>,
}
/// A guest->host TCP connection prepared by TSI_CONNECT and waiting for the
/// guest's follow-up stream REQUEST on `TSI_PROXY_PORT`.
pub struct PendingOutbound {
    /// Guest context id.
    pub cid: u64,
    /// Guest-side port the TSI state was keyed under.
    pub peer_port: u32,
    /// The connected host socket; taken when the guest's REQUEST arrives.
    pub tcp: Option<TcpStream>,
}
/// Flow-control bookkeeping for an established stream connection.
pub struct InboundState {
    /// Total bytes received from the guest (wrapping counter).
    pub our_fwd_cnt: u32,
    /// `our_fwd_cnt` value at the time of the last CREDIT_UPDATE we sent.
    pub last_credit_fwd_cnt: u32,
    /// Guest-side port to address reply packets to.
    pub guest_dst_port: u32,
}
/// Receive-buffer credit advertised to the guest — effectively "unlimited"
/// so the guest never throttles its transmit path on our account.
pub const OUR_BUF_ALLOC: u32 = u32::MAX / 2;
/// Send a CREDIT_UPDATE once the guest has sent this many bytes since the
/// previous update (wrapping arithmetic in `handle_rw` handles overflow).
const CREDIT_UPDATE_THRESHOLD: u32 = OUR_BUF_ALLOC / 2;
/// Read a little-endian `u16` from `body` at `offset`; `None` when the slice
/// is too short.
fn body_le_u16(body: &[u8], offset: usize) -> Option<u16> {
    body.get(offset..offset + 2)
        .and_then(|s| <[u8; 2]>::try_from(s).ok())
        .map(u16::from_le_bytes)
}
/// Read a big-endian (network-order) `u16` from `body` at `offset`;
/// `None` when the slice is too short.
fn body_be_u16(body: &[u8], offset: usize) -> Option<u16> {
    body.get(offset..offset + 2)
        .and_then(|s| <[u8; 2]>::try_from(s).ok())
        .map(u16::from_be_bytes)
}
/// Read a little-endian `u32` from `body` at `offset`; `None` when the slice
/// is too short.
fn body_le_u32(body: &[u8], offset: usize) -> Option<u32> {
    body.get(offset..offset + 4)
        .and_then(|s| <[u8; 4]>::try_from(s).ok())
        .map(u32::from_le_bytes)
}
/// JSON blob (env + secrets) served to the guest over `VSOCK_ENV_PORT`.
/// `None` until `set_env_json` is called; readers fall back to an empty blob.
pub static VSOCK_ENV_JSON: Mutex<Option<String>> = Mutex::new(None);
/// Replace the JSON blob handed out on the next guest env request.
pub fn set_env_json(json: String) {
    let mut slot = VSOCK_ENV_JSON.lock().unwrap();
    *slot = Some(json);
}
/// Whether verbose muxer tracing is enabled via the
/// `SUPERMACHINE_VSOCK_TRACE` environment variable.
///
/// The env lookup happens at most once per process; the answer is cached in
/// an atomic (0 = unknown, 1 = disabled, 2 = enabled).
#[inline]
pub fn vsock_trace_enabled() -> bool {
    use std::sync::atomic::{AtomicU8, Ordering};
    static CACHED: AtomicU8 = AtomicU8::new(0);
    match CACHED.load(Ordering::Relaxed) {
        1 => false,
        2 => true,
        _ => {
            let enabled = std::env::var_os("SUPERMACHINE_VSOCK_TRACE").is_some();
            CACHED.store(if enabled { 2 } else { 1 }, Ordering::Relaxed);
            enabled
        }
    }
}
/// Host-side vsock multiplexer: bridges guest vsock streams and datagrams to
/// host TCP/UDP/Unix sockets, delegating socket readiness to a dedicated
/// muxer I/O thread (`muxer_thread`).
pub struct VsockMuxer {
    /// Guest context id this muxer serves.
    cid: u64,
    /// Datagram proxies keyed by `proxy_key(dst_port, src_port)`.
    proxies: Mutex<HashMap<u64, Box<dyn Proxy>>>,
    /// Host->guest connections awaiting the guest's RESPONSE;
    /// keyed by (cid, host_src_port). Shared with accept callbacks.
    pending_inbound: Arc<Mutex<HashMap<(u64, u32), PendingInbound>>>,
    /// Guest->host TCP sockets prepared by TSI_CONNECT, awaiting the guest's
    /// stream REQUEST; keyed by (cid, peer_port).
    pending_outbound: Mutex<HashMap<(u64, u32), PendingOutbound>>,
    /// UDP destination per (cid, peer_port), set by CONNECT/SENDTO_ADDR.
    udp_dst: Mutex<HashMap<(u64, u32), SocketAddr>>,
    /// Lazily-bound host UDP sockets per (cid, peer_port).
    udp_sockets: Mutex<HashMap<(u64, u32), Arc<UdpSocket>>>,
    /// Established stream connections and their flow-control counters.
    inbound_conns: Mutex<HashMap<(u64, u32), InboundState>>,
    /// Command channel to the muxer I/O thread.
    io_tx: mpsc::Sender<MuxerCmd>,
    /// Wakes the I/O thread's poll loop after a command is queued.
    io_waker: Arc<mio::Waker>,
    /// TSI control state keyed by (cid, peer_port).
    tsi: Mutex<HashMap<(u64, u32), TsiState>>,
    /// Queue of packets destined for the guest.
    pub rxq: Arc<Mutex<MuxerRxQ>>,
    /// Callback that kicks the virtio device when the rxq becomes non-empty.
    kick_device: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
    /// Optional shared secret; when set, TSI control DGRAMs must be prefixed
    /// with it or they are silently dropped.
    pub(crate) tsi_token: Option<[u8; TSI_TOKEN_LEN]>,
}
/// Byte length of the shared-secret token prefixed to TSI control DGRAMs.
pub const TSI_TOKEN_LEN: usize = 32;
/// Constant-time byte-slice comparison: once lengths match, every byte is
/// visited so timing does not leak the position of the first mismatch.
#[inline]
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y))
        == 0
}
impl VsockMuxer {
pub fn new(cid: u64) -> Result<Self, muxer_thread::StartError> {
Self::with_tsi_token(cid, None)
}
/// Create a muxer for guest `cid`.
///
/// When `token` is `Some`, every TSI control DGRAM must carry that token as
/// a prefix or it is dropped. Spawns the muxer I/O thread; returns
/// `StartError` if the thread cannot be started.
pub fn with_tsi_token(
    cid: u64,
    token: Option<[u8; TSI_TOKEN_LEN]>,
) -> Result<Self, muxer_thread::StartError> {
    let (tx, waker) = muxer_thread::spawn()?;
    Ok(Self {
        cid,
        io_tx: tx,
        io_waker: waker,
        tsi_token: token,
        proxies: Mutex::new(HashMap::new()),
        pending_inbound: Arc::new(Mutex::new(HashMap::new())),
        pending_outbound: Mutex::new(HashMap::new()),
        udp_dst: Mutex::new(HashMap::new()),
        udp_sockets: Mutex::new(HashMap::new()),
        inbound_conns: Mutex::new(HashMap::new()),
        tsi: Mutex::new(HashMap::new()),
        rxq: Arc::new(Mutex::new(MuxerRxQ::new())),
        kick_device: Mutex::new(None),
    })
}
/// Install the callback used to kick the virtio device when the RX queue
/// transitions from empty to non-empty.
pub fn set_kick(&self, kick: Arc<dyn Fn() + Send + Sync>) {
    let mut slot = self.kick_device.lock().unwrap();
    *slot = Some(kick);
}
/// True when no handshakes are pending, no stream connections are live, and
/// no packets are queued for the guest.
pub fn is_transport_idle(&self) -> bool {
    let no_pending_in = self.pending_inbound.lock().unwrap().is_empty();
    let no_pending_out = self.pending_outbound.lock().unwrap().is_empty();
    let no_conns = self.inbound_conns.lock().unwrap().is_empty();
    let rxq_empty = self.rxq.lock().unwrap().is_empty();
    no_pending_in && no_pending_out && no_conns && rxq_empty
}
/// Drop all transport state, e.g. before restoring a snapshot.
///
/// Order matters: TSI control state is cleared first, then the I/O thread
/// is asked to drop its registrations (bounded 50ms wait so a wedged thread
/// cannot hang the caller), and only then are the connection maps and the
/// RX queue emptied.
pub fn reset(&self) {
    self.tsi.lock().unwrap().clear();
    let (done_tx, done_rx) = mpsc::channel();
    if self.io_tx.send(MuxerCmd::Reset { done: done_tx }).is_ok() {
        let _ = self.io_waker.wake();
        // Best effort: proceed even if the I/O thread does not ack in time.
        let _ = done_rx.recv_timeout(std::time::Duration::from_millis(50));
    }
    self.proxies.lock().unwrap().clear();
    self.pending_inbound.lock().unwrap().clear();
    self.pending_outbound.lock().unwrap().clear();
    self.udp_dst.lock().unwrap().clear();
    self.udp_sockets.lock().unwrap().clear();
    self.inbound_conns.lock().unwrap().clear();
    self.rxq.lock().unwrap().drain();
}
/// Host TCP port of the TSI listener bound for guest port `vm_port`, if any.
/// States that match `vm_port` but have no live listener are skipped.
pub fn host_port_for_vm_port(&self, vm_port: u32) -> Option<u16> {
    let states = self.tsi.lock().unwrap();
    states
        .values()
        .filter(|st| st.vm_port == Some(vm_port))
        .find_map(|st| st.listener.as_ref().map(|l| l.host_addr.port()))
}
pub fn first_host_port(&self) -> Option<u16> {
let s = self.tsi.lock().unwrap();
s.values()
.find_map(|st| {
if st.family == libc::AF_INET as u16 {
st.listener.as_ref().map(|l| l.host_addr.port())
} else {
None
}
})
.or_else(|| {
s.values()
.find_map(|st| st.listener.as_ref().map(|l| l.host_addr.port()))
})
}
/// Choose a `(cid, vm_port)` endpoint for a host-initiated connection.
///
/// When `vm_port` is `Some`, only a state registered on that exact guest
/// port qualifies; otherwise any registered state does. `AF_INET` entries
/// are preferred, with a fallback scan over all families.
fn first_listener_endpoint(&self, vm_port: Option<u32>) -> Option<(u64, u32)> {
    let states = self.tsi.lock().unwrap();
    let pick = |cid: u64, st: &TsiState| -> Option<(u64, u32)> {
        let registered = st.vm_port?;
        match vm_port {
            Some(want) if want != registered => None,
            _ => Some((cid, registered)),
        }
    };
    let ipv4 = states.iter().find_map(|((cid, _), st)| {
        if st.family == libc::AF_INET as u16 {
            pick(*cid, st)
        } else {
            None
        }
    });
    ipv4.or_else(|| states.iter().find_map(|((cid, _), st)| pick(*cid, st)))
}
/// Start a host->guest stream connection through a TSI listener endpoint.
///
/// Parks `stream` in `pending_inbound` keyed by a freshly allocated host
/// source port, then queues a stream REQUEST to the guest; the handshake
/// completes in `handle_response`, which registers the stream with the I/O
/// thread. Fails with `NotConnected` when no suitable listener exists.
fn open_stream_to_guest(
    &self,
    stream: MuxerStream,
    vm_port: Option<u32>,
) -> std::io::Result<()> {
    let Some((cid, vm_port)) = self.first_listener_endpoint(vm_port) else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::NotConnected,
            "no TSI listener",
        ));
    };
    // Unique host-side port doubles as the connection key.
    let host_src_port = crate::devices::virtio::vsock::tsi_stream::alloc_host_src_port();
    self.pending_inbound.lock().unwrap().insert(
        (cid, host_src_port),
        PendingInbound {
            cid,
            host_src_port,
            vm_port,
            stream: Some(stream),
        },
    );
    // Connection-open REQUEST; the guest answers with RESPONSE (or RST).
    let req = RxPacket {
        hdr: Header {
            src_cid: VSOCK_HOST_CID,
            dst_cid: cid,
            src_port: host_src_port,
            dst_port: vm_port,
            len: 0,
            type_: VSOCK_TYPE_STREAM,
            op: VSOCK_OP_REQUEST,
            flags: 0,
            buf_alloc: OUR_BUF_ALLOC,
            fwd_cnt: 0,
        },
        data: Vec::new(),
    };
    let kick = self.kick_device.lock().unwrap().clone();
    push_rxq_and_kick(&self.rxq, &kick, req);
    Ok(())
}
/// Hand a host Unix-domain stream to the guest via a TSI listener endpoint.
pub fn open_unix_to_guest(
    &self,
    unix: UnixStream,
    vm_port: Option<u32>,
) -> std::io::Result<()> {
    let stream = MuxerStream::Unix(unix);
    self.open_stream_to_guest(stream, vm_port)
}
/// Start a host->guest connection directly to `guest_port`, bypassing the
/// TSI listener table (unlike `open_stream_to_guest`, no LISTEN is needed —
/// the guest must simply be accepting on `guest_port`).
pub fn open_native_to_guest(
    &self,
    stream: MuxerStream,
    guest_port: u32,
) -> std::io::Result<()> {
    let cid = self.cid;
    // Unique host-side port doubles as the connection key.
    let host_src_port = crate::devices::virtio::vsock::tsi_stream::alloc_host_src_port();
    self.pending_inbound.lock().unwrap().insert(
        (cid, host_src_port),
        PendingInbound {
            cid,
            host_src_port,
            vm_port: guest_port,
            stream: Some(stream),
        },
    );
    // Connection-open REQUEST; completion happens in `handle_response`.
    let req = RxPacket {
        hdr: Header {
            src_cid: VSOCK_HOST_CID,
            dst_cid: cid,
            src_port: host_src_port,
            dst_port: guest_port,
            len: 0,
            type_: VSOCK_TYPE_STREAM,
            op: VSOCK_OP_REQUEST,
            flags: 0,
            buf_alloc: OUR_BUF_ALLOC,
            fwd_cnt: 0,
        },
        data: Vec::new(),
    };
    let kick = self.kick_device.lock().unwrap().clone();
    push_rxq_and_kick(&self.rxq, &kick, req);
    Ok(())
}
/// Hand a host TCP stream to the guest via a TSI listener endpoint.
pub fn open_tcp_to_guest(&self, tcp: TcpStream, vm_port: Option<u32>) -> std::io::Result<()> {
    let stream = MuxerStream::Tcp(tcp);
    self.open_stream_to_guest(stream, vm_port)
}
/// Like `open_tcp_to_guest`, but `prefix` bytes are delivered to the guest
/// before anything read from the socket (e.g. data already consumed while
/// sniffing the protocol).
pub fn open_tcp_to_guest_with_prefix(
    &self,
    tcp: TcpStream,
    prefix: Vec<u8>,
    vm_port: Option<u32>,
) -> std::io::Result<()> {
    let stream = MuxerStream::TcpWithPrefix(tcp, prefix);
    self.open_stream_to_guest(stream, vm_port)
}
/// Snapshot every TSI state that completed a LISTEN (i.e. has a `vm_port`),
/// for later re-binding via `restore_tsi_listeners`.
pub fn capture_tsi_listeners(&self) -> Vec<TsiListenerSnapshot> {
    let states = self.tsi.lock().unwrap();
    let mut snaps = Vec::new();
    for ((cid, peer_port), st) in states.iter() {
        if let Some(vm_port) = st.vm_port {
            snaps.push(TsiListenerSnapshot {
                cid: *cid,
                peer_port: *peer_port,
                vm_port,
                family: st.family,
                socktype: st.socktype,
            });
        }
    }
    snaps
}
/// Number of TSI states with an assigned guest listen port.
pub fn listener_count(&self) -> usize {
    self.tsi
        .lock()
        .unwrap()
        .values()
        .filter(|st| st.vm_port.is_some())
        .count()
}
/// Re-bind previously captured TSI listeners (see `capture_tsi_listeners`).
///
/// Each restored listener gets an accept callback that parks the accepted
/// TCP stream in `pending_inbound` and queues a connection REQUEST to the
/// guest — the same handshake path as `open_stream_to_guest`. Bind failures
/// are logged and skipped; the remaining snapshots are still restored.
pub fn restore_tsi_listeners(self: &Arc<Self>, snaps: &[TsiListenerSnapshot]) {
    for s in snaps {
        // Clone the shared handles the accept callback needs; the callback
        // runs on the listener's accept thread, not on the VMM thread.
        let pending = self.pending_inbound.clone();
        let rxq = self.rxq.clone();
        let kick = self.kick_device.lock().unwrap().clone();
        let on_accept: Arc<dyn Fn(u64, u32, u32, std::net::TcpStream) + Send + Sync> =
            Arc::new(move |cid, host_src_port, vm_port, tcp| {
                pending.lock().unwrap().insert(
                    (cid, host_src_port),
                    PendingInbound {
                        cid,
                        host_src_port,
                        vm_port,
                        stream: Some(MuxerStream::Tcp(tcp)),
                    },
                );
                let req = RxPacket {
                    hdr: Header {
                        src_cid: VSOCK_HOST_CID,
                        dst_cid: cid,
                        src_port: host_src_port,
                        dst_port: vm_port,
                        len: 0,
                        type_: VSOCK_TYPE_STREAM,
                        op: VSOCK_OP_REQUEST,
                        flags: 0,
                        buf_alloc: OUR_BUF_ALLOC,
                        fwd_cnt: 0,
                    },
                    data: Vec::new(),
                };
                push_rxq_and_kick(&rxq, &kick, req);
            });
        match TsiListener::bind(s.cid, s.vm_port, on_accept) {
            Ok(listener) => {
                if vsock_trace_enabled() {
                    eprintln!(
                        "[muxer] restored TSI listener cid={} vm_port={} -> host {}",
                        s.cid, s.vm_port, listener.host_addr
                    );
                }
                self.tsi.lock().unwrap().insert(
                    (s.cid, s.peer_port),
                    TsiState {
                        family: s.family,
                        socktype: s.socktype,
                        vm_port: Some(s.vm_port),
                        listener: Some(listener),
                    },
                );
            }
            Err(e) => {
                eprintln!(
                    "[muxer] restore listener cid={} vm_port={} ERR: {e}",
                    s.cid, s.vm_port
                );
            }
        }
    }
}
/// Clone of the shared pending-inbound map handle, for accept callbacks.
fn pending_inbound_arc(&self) -> Arc<Mutex<HashMap<(u64, u32), PendingInbound>>> {
    Arc::clone(&self.pending_inbound)
}
/// Queue a fully-formed packet for the guest, kicking the device if the
/// RX queue was empty.
pub fn submit(&self, pkt: RxPacket) {
    let kick_fn = self.kick_device.lock().unwrap().clone();
    push_rxq_and_kick(&self.rxq, &kick_fn, pkt);
}
/// Route one guest->host packet; returns any immediate replies for the guest.
pub fn handle_tx(&self, hdr: &Header, payload: &[u8]) -> Vec<RxPacket> {
    // TSI control messages are DGRAMs whose destination port encodes the
    // opcode (1024..=1031). This check runs first so control traffic can
    // never be mistaken for data (or for the env port, which shares 1026).
    if hdr.type_ == VSOCK_TYPE_DGRAM
        && hdr.dst_port >= TSI_PROXY_CREATE
        && hdr.dst_port <= TSI_PROXY_RELEASE
    {
        return self.handle_tsi_control(hdr, payload);
    }
    // Non-control DGRAM RW carries UDP payload bytes.
    // NOTE(review): this path keys UDP state on hdr.dst_port while the
    // TSI_SENDTO_DATA path keys on hdr.src_port — confirm both match how the
    // guest addresses its datagram flows.
    if hdr.type_ == VSOCK_TYPE_DGRAM && hdr.op == VSOCK_OP_RW {
        self.send_udp_payload(hdr.src_cid, hdr.dst_port, hdr.src_port, payload);
        return Vec::new();
    }
    // Stream traffic dispatches on the vsock opcode.
    match hdr.op {
        VSOCK_OP_REQUEST => self.handle_request(hdr),
        VSOCK_OP_RESPONSE => self.handle_response(hdr),
        VSOCK_OP_RW => self.handle_rw(hdr, payload),
        VSOCK_OP_SHUTDOWN | VSOCK_OP_RST => self.handle_close(hdr),
        VSOCK_OP_CREDIT_UPDATE | VSOCK_OP_CREDIT_REQUEST => self.handle_credit(hdr),
        _ => Vec::new(),
    }
}
/// Handle a guest-initiated stream REQUEST (connection open).
///
/// Two destinations are recognized:
/// * `VSOCK_ENV_PORT` — reply with RESPONSE, one RW packet carrying the
///   env/secrets JSON, then SHUTDOWN (flags 1|2 = both directions closed).
/// * `TSI_PROXY_PORT` — complete an outbound TCP connection prepared by
///   TSI_CONNECT: hand the socket to the I/O thread and ack with RESPONSE.
/// Anything else is refused with RST.
fn handle_request(&self, hdr: &Header) -> Vec<RxPacket> {
    if hdr.dst_port == VSOCK_ENV_PORT {
        // Serve the env blob in a single synthetic request/response exchange.
        let json = VSOCK_ENV_JSON
            .lock()
            .unwrap()
            .clone()
            .unwrap_or_else(|| r#"{"env":{},"secrets":{}}"#.into());
        // Helper stamping reply headers with src/dst mirrored from `hdr`.
        let mk = |op: u16, flags: u32, data: Vec<u8>| RxPacket {
            hdr: Header {
                src_cid: hdr.dst_cid,
                dst_cid: hdr.src_cid,
                src_port: hdr.dst_port,
                dst_port: hdr.src_port,
                len: data.len() as u32,
                type_: VSOCK_TYPE_STREAM,
                op,
                flags,
                buf_alloc: OUR_BUF_ALLOC,
                fwd_cnt: 0,
            },
            data,
        };
        return vec![
            mk(VSOCK_OP_RESPONSE, 0, Vec::new()),
            mk(VSOCK_OP_RW, 0, json.into_bytes()),
            mk(VSOCK_OP_SHUTDOWN, 1u32 | 2u32, Vec::new()),
        ];
    }
    if hdr.dst_port == TSI_PROXY_PORT {
        // Outbound connections are keyed by the guest's source port.
        let key = (hdr.src_cid, hdr.src_port);
        let mut pending = self.pending_outbound.lock().unwrap();
        if let Some(mut po) = pending.remove(&key) {
            let tcp = match po.tcp.take() {
                // Socket already consumed — refuse.
                Some(t) => t,
                None => return vec![RxPacket::rst_for(hdr)],
            };
            // Release the pending-outbound lock before taking other locks.
            drop(pending);
            // The guest's src_port doubles as the I/O-thread key and as the
            // guest-side destination for data flowing back.
            let host_src_port = hdr.src_port;
            let guest_dst_port = hdr.src_port;
            let rxq = self.rxq.clone();
            let kick = self.kick_device.lock().unwrap().clone();
            let cid = po.cid;
            // Runs on the I/O thread: empty `data` signals EOF -> SHUTDOWN.
            let on_data: Arc<dyn Fn(Vec<u8>) + Send + Sync> = Arc::new(move |data| {
                let (op, flags) = if data.is_empty() {
                    (VSOCK_OP_SHUTDOWN, 1u32 | 2u32)
                } else {
                    (VSOCK_OP_RW, 0u32)
                };
                let pkt = RxPacket {
                    hdr: Header {
                        src_cid: VSOCK_HOST_CID,
                        dst_cid: cid,
                        src_port: TSI_PROXY_PORT,
                        dst_port: guest_dst_port,
                        len: data.len() as u32,
                        type_: VSOCK_TYPE_STREAM,
                        op,
                        flags,
                        buf_alloc: OUR_BUF_ALLOC,
                        fwd_cnt: 0,
                    },
                    data,
                };
                push_rxq_and_kick(&rxq, &kick, pkt);
            });
            let _ = self.io_tx.send(MuxerCmd::Register {
                host_src_port,
                stream: MuxerStream::Tcp(tcp),
                on_data,
            });
            let _ = self.io_waker.wake();
            // Track flow-control state for handle_rw's credit accounting.
            self.inbound_conns.lock().unwrap().insert(
                (cid, host_src_port),
                InboundState {
                    our_fwd_cnt: 0,
                    last_credit_fwd_cnt: 0,
                    guest_dst_port,
                },
            );
            if vsock_trace_enabled() {
                eprintln!(
                    "[muxer] outbound conn established cid={cid} peer_port={host_src_port}"
                );
            }
            return vec![RxPacket {
                hdr: Header {
                    src_cid: VSOCK_HOST_CID,
                    dst_cid: cid,
                    src_port: TSI_PROXY_PORT,
                    dst_port: hdr.src_port,
                    len: 0,
                    type_: VSOCK_TYPE_STREAM,
                    op: VSOCK_OP_RESPONSE,
                    flags: 0,
                    buf_alloc: OUR_BUF_ALLOC,
                    fwd_cnt: 0,
                },
                data: Vec::new(),
            }];
        }
        // No prepared socket for this key.
        return vec![RxPacket::rst_for(hdr)];
    }
    vec![RxPacket::rst_for(hdr)]
}
/// Complete a host-initiated handshake: the guest accepted our REQUEST, so
/// hand the parked stream to the I/O thread and start tracking flow control.
/// An unknown or already-consumed key is ignored (no reply).
fn handle_response(&self, hdr: &Header) -> Vec<RxPacket> {
    // We sent REQUEST with src_port = host_src_port, so the guest's reply
    // carries it in dst_port.
    let key = (hdr.src_cid, hdr.dst_port);
    let pending = self.pending_inbound.lock().unwrap().remove(&key);
    if let Some(mut pi) = pending {
        let stream = match pi.stream.take() {
            Some(s) => s,
            None => return vec![RxPacket::rst_for(hdr)],
        };
        let rxq = self.rxq.clone();
        let kick = self.kick_device.lock().unwrap().clone();
        let cid = pi.cid;
        let host_src_port = pi.host_src_port;
        let guest_dst_port = hdr.src_port;
        // Runs on the I/O thread: empty `data` signals EOF -> SHUTDOWN.
        let on_data: Arc<dyn Fn(Vec<u8>) + Send + Sync> = Arc::new(move |data| {
            let (op, flags) = if data.is_empty() {
                (VSOCK_OP_SHUTDOWN, 1u32 | 2u32)
            } else {
                (VSOCK_OP_RW, 0u32)
            };
            let pkt = RxPacket {
                hdr: Header {
                    src_cid: VSOCK_HOST_CID,
                    dst_cid: cid,
                    src_port: host_src_port,
                    dst_port: guest_dst_port,
                    len: data.len() as u32,
                    type_: VSOCK_TYPE_STREAM,
                    op,
                    flags,
                    buf_alloc: OUR_BUF_ALLOC,
                    fwd_cnt: 0,
                },
                data,
            };
            push_rxq_and_kick(&rxq, &kick, pkt);
        });
        let _ = self.io_tx.send(MuxerCmd::Register {
            host_src_port,
            stream,
            on_data,
        });
        let _ = self.io_waker.wake();
        // Track flow-control state for handle_rw's credit accounting.
        self.inbound_conns.lock().unwrap().insert(
            key,
            InboundState {
                our_fwd_cnt: 0,
                last_credit_fwd_cnt: 0,
                guest_dst_port,
            },
        );
        if vsock_trace_enabled() {
            eprintln!("[muxer] inbound conn established cid={cid} host_src={host_src_port} guest_dst={guest_dst_port}");
        }
    }
    Vec::new()
}
/// Forward guest stream data to the backing host socket and maintain
/// receive-credit accounting, emitting a CREDIT_UPDATE once enough bytes
/// have been consumed since the last one.
///
/// Packets that match no tracked connection fall back to the datagram proxy
/// table; with no proxy either, the guest gets an RST.
fn handle_rw(&self, hdr: &Header, payload: &[u8]) -> Vec<RxPacket> {
    // Outbound (guest-dialed) connections are keyed by the guest's source
    // port; inbound (host-dialed) ones by the host port in dst_port.
    let conn_key = if hdr.dst_port == TSI_PROXY_PORT {
        (hdr.src_cid, hdr.src_port)
    } else {
        (hdr.src_cid, hdr.dst_port)
    };
    let needs_credit_update = {
        let mut conns = self.inbound_conns.lock().unwrap();
        if let Some(st) = conns.get_mut(&conn_key) {
            // Wrapping counters match the guest's 32-bit fwd_cnt arithmetic.
            st.our_fwd_cnt = st.our_fwd_cnt.wrapping_add(payload.len() as u32);
            let consumed_since_last = st.our_fwd_cnt.wrapping_sub(st.last_credit_fwd_cnt);
            // Sampled trace: log roughly once per 50 KB of received data.
            if vsock_trace_enabled()
                && st.our_fwd_cnt % 50000 < 1500
            {
                eprintln!(
                    "[muxer-rw-fc] our_fwd_cnt={} last_credit={} consumed_since_last={} thr={}",
                    st.our_fwd_cnt,
                    st.last_credit_fwd_cnt,
                    consumed_since_last,
                    CREDIT_UPDATE_THRESHOLD
                );
            }
            if consumed_since_last >= CREDIT_UPDATE_THRESHOLD {
                let fwd_cnt = st.our_fwd_cnt;
                let guest_dst = st.guest_dst_port;
                st.last_credit_fwd_cnt = fwd_cnt;
                Some((fwd_cnt, guest_dst))
            } else {
                None
            }
        } else {
            // Not a tracked stream: try the datagram proxy table instead.
            // Both branches below return from the function directly.
            drop(conns);
            let key = proxy_key(hdr.dst_port, hdr.src_port);
            let mut proxies = self.proxies.lock().unwrap();
            if let Some(p) = proxies.get_mut(&key) {
                return p.handle_packet(hdr, payload);
            }
            return vec![RxPacket::rst_for(hdr)];
        }
    };
    // Ship the payload to the I/O thread (keyed by the host-side port).
    if !payload.is_empty() {
        let t0 = Instant::now();
        let _ = self.io_tx.send(MuxerCmd::Write {
            host_src_port: conn_key.1,
            bytes: payload.to_vec(),
        });
        let _ = self.io_waker.wake();
        mux_profile::record(
            Stage::GuestToTcpSend,
            payload.len(),
            t0.elapsed().as_micros() as u64,
        );
    }
    if let Some((fwd_cnt, guest_dst_port)) = needs_credit_update {
        return vec![RxPacket {
            hdr: Header {
                src_cid: VSOCK_HOST_CID,
                dst_cid: hdr.src_cid,
                src_port: hdr.dst_port,
                dst_port: guest_dst_port,
                len: 0,
                type_: VSOCK_TYPE_STREAM,
                op: VSOCK_OP_CREDIT_UPDATE,
                flags: 0,
                buf_alloc: OUR_BUF_ALLOC,
                fwd_cnt,
            },
            data: Vec::new(),
        }];
    }
    Vec::new()
}
/// Tear down state for a guest SHUTDOWN or RST: drop any datagram proxy for
/// the port pair and, if a tracked stream connection exists, close its host
/// socket on the I/O thread. Never replies.
fn handle_close(&self, hdr: &Header) -> Vec<RxPacket> {
    let key = proxy_key(hdr.dst_port, hdr.src_port);
    self.proxies.lock().unwrap().remove(&key);
    // Same connection keying as handle_rw: outbound conns use the guest's
    // source port, inbound conns the host-chosen dst port.
    let conn_key = if hdr.dst_port == TSI_PROXY_PORT {
        (hdr.src_cid, hdr.src_port)
    } else {
        (hdr.src_cid, hdr.dst_port)
    };
    let had_conn = self
        .inbound_conns
        .lock()
        .unwrap()
        .remove(&conn_key)
        .is_some();
    if had_conn {
        let _ = self.io_tx.send(MuxerCmd::Close {
            host_src_port: conn_key.1,
        });
        let _ = self.io_waker.wake();
    }
    Vec::new()
}
fn handle_credit(&self, _hdr: &Header) -> Vec<RxPacket> {
Vec::new()
}
/// Dispatch a TSI control DGRAM (opcode = destination port).
///
/// When a token is configured, the DGRAM body must start with it or the
/// message is dropped without reply (constant-time compare). Replies carry a
/// little-endian `i32` result code in a DGRAM RW packet: 0 = success,
/// negative values mirror errno (-22 EINVAL, -97 EAFNOSUPPORT,
/// -13 EACCES, -111 ECONNREFUSED).
fn handle_tsi_control(&self, hdr: &Header, body: &[u8]) -> Vec<RxPacket> {
    // Authenticate and strip the token prefix, if one is required.
    let body: &[u8] = if let Some(expected) = self.tsi_token.as_ref() {
        if body.len() < TSI_TOKEN_LEN || !ct_eq(&body[..TSI_TOKEN_LEN], &expected[..]) {
            return Vec::new();
        }
        &body[TSI_TOKEN_LEN..]
    } else {
        body
    };
    // Build a result-code reply addressed back to the sender.
    let mk_resp = |result: i32| -> RxPacket {
        let bytes = result.to_le_bytes().to_vec();
        RxPacket {
            hdr: Header {
                src_cid: hdr.dst_cid,
                dst_cid: hdr.src_cid,
                src_port: hdr.dst_port,
                dst_port: hdr.src_port,
                len: bytes.len() as u32,
                type_: VSOCK_TYPE_DGRAM,
                op: VSOCK_OP_RW,
                flags: 0,
                buf_alloc: OUR_BUF_ALLOC,
                fwd_cnt: 0,
            },
            data: bytes,
        }
    };
    match hdr.dst_port {
        // Body: peer_port:u32 LE, family:u16 LE, socktype:u16 LE.
        // Creates the per-socket state slot; no reply is expected.
        TSI_PROXY_CREATE if body.len() >= 8 => {
            let (Some(peer_port), Some(family), Some(socktype)) = (
                body_le_u32(body, 0),
                body_le_u16(body, 4),
                body_le_u16(body, 6),
            ) else {
                return Vec::new();
            };
            eprintln!(
                "[muxer] PROXY_CREATE peer_port={peer_port} family={family} type={socktype}"
            );
            self.tsi.lock().unwrap().insert(
                (hdr.src_cid, peer_port),
                TsiState {
                    family,
                    socktype,
                    vm_port: None,
                    listener: None,
                },
            );
            Vec::new()
        }
        // Body: peer_port:u32 LE, vm_port:u32 LE. Binds a host listener that
        // forwards accepted TCP connections into the guest on vm_port.
        TSI_LISTEN if body.len() >= 8 => {
            let (Some(peer_port), Some(vm_port)) = (body_le_u32(body, 0), body_le_u32(body, 4))
            else {
                return vec![mk_resp(-22)];
            };
            let pending = self.pending_inbound_arc();
            let rxq = self.rxq.clone();
            let kick = self.kick_device.lock().unwrap().clone();
            // Accept callback: same handshake path as open_stream_to_guest.
            let on_accept: Arc<dyn Fn(u64, u32, u32, TcpStream) + Send + Sync> =
                Arc::new(move |cid, host_src_port, vm_port, tcp| {
                    pending.lock().unwrap().insert(
                        (cid, host_src_port),
                        PendingInbound {
                            cid,
                            host_src_port,
                            vm_port,
                            stream: Some(MuxerStream::Tcp(tcp)),
                        },
                    );
                    let req = RxPacket {
                        hdr: Header {
                            src_cid: VSOCK_HOST_CID,
                            dst_cid: cid,
                            src_port: host_src_port,
                            dst_port: vm_port,
                            len: 0,
                            type_: VSOCK_TYPE_STREAM,
                            op: VSOCK_OP_REQUEST,
                            flags: 0,
                            buf_alloc: OUR_BUF_ALLOC,
                            fwd_cnt: 0,
                        },
                        data: Vec::new(),
                    };
                    push_rxq_and_kick(&rxq, &kick, req);
                });
            let listener = TsiListener::bind(hdr.src_cid, vm_port, on_accept).ok();
            let host = listener.as_ref().map(|l| l.host_addr);
            eprintln!(
                "[muxer] LISTEN peer_port={peer_port} vm_port={vm_port} -> host {host:?}"
            );
            // NOTE(review): vm_port is recorded even when the host bind
            // failed (listener = None) — confirm that is intended.
            if let Some(s) = self.tsi.lock().unwrap().get_mut(&(hdr.src_cid, peer_port)) {
                s.vm_port = Some(vm_port);
                s.listener = listener;
            }
            vec![mk_resp(if host.is_some() { 0 } else { -22 })]
        }
        // Body: peer_port:u32 LE. Succeeds iff the state slot exists.
        TSI_ACCEPT if body.len() >= 4 => {
            let Some(peer_port) = body_le_u32(body, 0) else {
                return vec![mk_resp(-22)];
            };
            let s = self.tsi.lock().unwrap();
            let r = if s.contains_key(&(hdr.src_cid, peer_port)) {
                0
            } else {
                -22
            };
            vec![mk_resp(r)]
        }
        // Reply layout: result:i32=0, addr_len:u32=16, then a sockaddr_in:
        // family=AF_INET (LE), port left zero at [10..12], addr 127.0.0.1.
        // NOTE(review): the port field is never filled in and the request
        // body is unused beyond the length check — confirm intentional.
        TSI_GETNAME if body.len() >= 12 => {
            let mut buf = vec![0u8; 4 + 4 + 128];
            buf[0..4].copy_from_slice(&0i32.to_le_bytes());
            buf[4..8].copy_from_slice(&16u32.to_le_bytes());
            buf[8..10].copy_from_slice(&2u16.to_le_bytes());
            buf[12..16].copy_from_slice(&[127, 0, 0, 1]);
            vec![RxPacket {
                hdr: Header {
                    src_cid: hdr.dst_cid,
                    dst_cid: hdr.src_cid,
                    src_port: hdr.dst_port,
                    dst_port: hdr.src_port,
                    len: buf.len() as u32,
                    type_: VSOCK_TYPE_DGRAM,
                    op: VSOCK_OP_RW,
                    flags: 0,
                    buf_alloc: OUR_BUF_ALLOC,
                    fwd_cnt: 0,
                },
                data: buf,
            }]
        }
        // Body: peer_port:u32 LE, addr_len:u32 LE, then a sockaddr
        // (family LE at [8..10], port big-endian at [10..12], IPv4 at
        // [12..16]). Only AF_INET (2) is supported.
        TSI_CONNECT if body.len() >= 8 => {
            let (Some(peer_port), Some(addr_len)) =
                (body_le_u32(body, 0), body_le_u32(body, 4))
            else {
                return vec![mk_resp(-22)];
            };
            let addr_len = addr_len as usize;
            // Length sanity: need at least a full sockaddr_in.
            if body.len() < 8 + addr_len.min(128) || addr_len < 16 {
                return vec![mk_resp(-22)];
            }
            let Some(family) = body_le_u16(body, 8) else {
                return vec![mk_resp(-22)];
            };
            if family != 2 {
                // -97 = EAFNOSUPPORT
                return vec![mk_resp(-97)];
            }
            let Some(port) = body_be_u16(body, 10) else {
                return vec![mk_resp(-22)];
            };
            let ip = std::net::Ipv4Addr::new(body[12], body[13], body[14], body[15]);
            let target = std::net::SocketAddr::from((ip, port));
            let socktype = self
                .tsi
                .lock()
                .unwrap()
                .get(&(hdr.src_cid, peer_port))
                .map(|s| s.socktype);
            // UDP "connect" just records the destination; no socket yet.
            if socktype == Some(libc::SOCK_DGRAM as u16) {
                if let Err(why) = crate::vmm::egress_policy::check_addr(target) {
                    eprintln!("[muxer] UDP CONNECT cid={} peer_port={peer_port} -> {target} BLOCKED: {why}",
                        hdr.src_cid);
                    // -13 = EACCES
                    return vec![mk_resp(-13)];
                }
                self.udp_dst
                    .lock()
                    .unwrap()
                    .insert((hdr.src_cid, peer_port), target);
                eprintln!(
                    "[muxer] UDP CONNECT cid={} peer_port={peer_port} -> {target} OK",
                    hdr.src_cid
                );
                return vec![mk_resp(0)];
            }
            // TCP path: policy check, then connect with a 2s timeout and
            // park the socket until the guest's stream REQUEST arrives.
            if let Err(why) = crate::vmm::egress_policy::check_addr(target) {
                eprintln!(
                    "[muxer] CONNECT cid={} peer_port={peer_port} -> {target} BLOCKED: {why}",
                    hdr.src_cid
                );
                return vec![mk_resp(-13)];
            }
            let res = std::net::TcpStream::connect_timeout(
                &target,
                std::time::Duration::from_secs(2),
            );
            match res {
                Ok(tcp) => {
                    let _ = tcp.set_nodelay(true);
                    let _ = tcp.set_nonblocking(true);
                    eprintln!(
                        "[muxer] CONNECT cid={} peer_port={peer_port} -> {target} OK",
                        hdr.src_cid
                    );
                    self.pending_outbound.lock().unwrap().insert(
                        (hdr.src_cid, peer_port),
                        PendingOutbound {
                            cid: hdr.src_cid,
                            peer_port,
                            tcp: Some(tcp),
                        },
                    );
                    vec![mk_resp(0)]
                }
                Err(e) => {
                    eprintln!(
                        "[muxer] CONNECT cid={} peer_port={peer_port} -> {target} ERR: {e}",
                        hdr.src_cid
                    );
                    // -111 = ECONNREFUSED
                    vec![mk_resp(-111)]
                }
            }
        }
        // Release is currently a no-op; per-connection teardown happens via
        // SHUTDOWN/RST on the stream path.
        TSI_PROXY_RELEASE => {
            Vec::new()
        }
        // Body: peer_port:u32 LE, addr_len:u32 LE, sockaddr as in CONNECT.
        // Records the UDP destination for subsequent SENDTO_DATA packets.
        TSI_SENDTO_ADDR if body.len() >= 8 => {
            let (Some(peer_port), Some(addr_len)) =
                (body_le_u32(body, 0), body_le_u32(body, 4))
            else {
                return Vec::new();
            };
            let addr_len = addr_len as usize;
            if body.len() < 8 + addr_len.min(128) || addr_len < 16 {
                return Vec::new();
            }
            let Some(family) = body_le_u16(body, 8) else {
                return Vec::new();
            };
            if family == 2 {
                let Some(port) = body_be_u16(body, 10) else {
                    return Vec::new();
                };
                let ip = std::net::Ipv4Addr::new(body[12], body[13], body[14], body[15]);
                let addr = SocketAddr::from((ip, port));
                if let Err(why) = crate::vmm::egress_policy::check_addr(addr) {
                    eprintln!(
                        "[muxer] SENDTO_ADDR cid={} -> {addr} BLOCKED: {why}",
                        hdr.src_cid
                    );
                } else {
                    self.udp_dst
                        .lock()
                        .unwrap()
                        .insert((hdr.src_cid, peer_port), addr);
                }
            }
            Vec::new()
        }
        // Entire body is the UDP payload.
        // NOTE(review): the guest's src_port is used both as the udp_dst key
        // and as the reply destination — confirm it matches the peer_port
        // used by SENDTO_ADDR.
        TSI_SENDTO_DATA => {
            let peer_port = hdr.src_port;
            self.send_udp_payload(hdr.src_cid, peer_port, hdr.src_port, body);
            Vec::new()
        }
        _ => Vec::new(),
    }
}
/// Send one UDP datagram for the guest flow identified by (cid, peer_port).
///
/// The destination must have been recorded earlier (TSI CONNECT /
/// SENDTO_ADDR); otherwise the payload is dropped. A host UDP socket is
/// bound lazily per flow and its receive side is registered with the I/O
/// thread so inbound datagrams flow back to the guest on `guest_dst_port`.
fn send_udp_payload(&self, cid: u64, peer_port: u32, guest_dst_port: u32, data: &[u8]) {
    // NOTE(review): zero-length datagrams are dropped here even though UDP
    // permits them — confirm this is acceptable for the guest workloads.
    if data.is_empty() {
        return;
    }
    let key = (cid, peer_port);
    let dst = match self.udp_dst.lock().unwrap().get(&key).cloned() {
        Some(a) => a,
        None => {
            if vsock_trace_enabled() {
                eprintln!("[muxer] UDP DATA cid={cid} peer_port={peer_port} with no dst");
            }
            return;
        }
    };
    // Reuse the per-flow socket, or bind + register one on first use.
    let sock_arc = {
        let mut s = self.udp_sockets.lock().unwrap();
        if let Some(s) = s.get(&key) {
            s.clone()
        } else {
            let udp = match UdpSocket::bind("0.0.0.0:0") {
                Ok(u) => u,
                Err(e) => {
                    eprintln!("[muxer] UDP bind cid={cid} peer_port={peer_port}: {e}");
                    return;
                }
            };
            // Cloned handle goes to the I/O thread for receiving; the
            // original stays here for sending.
            let recv_fd = match udp.try_clone() {
                Ok(fd) => fd,
                Err(e) => {
                    eprintln!("[muxer] UDP clone cid={cid} peer_port={peer_port}: {e}");
                    return;
                }
            };
            let arc = Arc::new(udp);
            s.insert(key, arc.clone());
            let rxq = self.rxq.clone();
            let kick = self.kick_device.lock().unwrap().clone();
            // Runs on the I/O thread for every datagram received from dst.
            let on_data: Arc<dyn Fn(Vec<u8>) + Send + Sync> = Arc::new(move |data| {
                let pkt = RxPacket {
                    hdr: Header {
                        src_cid: VSOCK_HOST_CID,
                        dst_cid: cid,
                        src_port: peer_port,
                        dst_port: guest_dst_port,
                        len: data.len() as u32,
                        type_: VSOCK_TYPE_DGRAM,
                        op: VSOCK_OP_RW,
                        flags: 0,
                        buf_alloc: OUR_BUF_ALLOC,
                        fwd_cnt: 0,
                    },
                    data,
                };
                push_rxq_and_kick(&rxq, &kick, pkt);
            });
            // NOTE(review): the I/O-thread key is only peer_port (no cid) —
            // confirm ports cannot collide across guests.
            let _ = self.io_tx.send(MuxerCmd::RegisterUdp {
                key: peer_port,
                udp: recv_fd,
                on_data,
            });
            let _ = self.io_waker.wake();
            arc
        }
    };
    if vsock_trace_enabled() {
        eprintln!(
            "[muxer] UDP DATA cid={cid} peer_port={peer_port} guest_dst={guest_dst_port} -> {dst} {}B",
            data.len()
        );
    }
    // Best effort: UDP send failures are intentionally ignored.
    let _ = sock_arc.send_to(data, dst);
}
}
/// Append `pkt` to the guest-bound RX queue and invoke the device kick, but
/// only when the queue transitioned from empty to non-empty (the device will
/// drain everything once kicked). Records queue/kick latency in the profiler.
fn push_rxq_and_kick(
    rxq: &Arc<Mutex<MuxerRxQ>>,
    kick: &Option<Arc<dyn Fn() + Send + Sync>>,
    pkt: RxPacket,
) {
    let bytes = pkt.data.len();
    let start = Instant::now();
    let transitioned = rxq.lock().unwrap().push_was_empty(pkt) == Some(true);
    if transitioned {
        if let Some(k) = kick {
            k();
        }
    }
    mux_profile::record(Stage::RxqKick, bytes, start.elapsed().as_micros() as u64);
}
// Reference VSOCK_HOST_CID at const-eval time so the import stays used even
// when trace-only code paths are compiled out.
const _: u64 = VSOCK_HOST_CID;
#[cfg(test)]
mod tests {
    use super::*;
    /// DGRAM header addressed from guest cid 3 to the given TSI control port.
    fn control_header(dst_port: u32) -> Header {
        Header {
            src_cid: 3,
            dst_cid: VSOCK_HOST_CID,
            src_port: 40_000,
            dst_port,
            len: 0,
            type_: VSOCK_TYPE_DGRAM,
            op: VSOCK_OP_RW,
            flags: 0,
            buf_alloc: OUR_BUF_ALLOC,
            fwd_cnt: 0,
        }
    }
    /// Decode the little-endian i32 result code from the first reply packet.
    fn response_i32(resp: &[RxPacket]) -> Option<i32> {
        let bytes = resp.first()?.data.get(0..4)?.try_into().ok()?;
        Some(i32::from_le_bytes(bytes))
    }
    // A CONNECT body with only peer_port + addr_len (no sockaddr) must be
    // rejected with -22 (EINVAL), not panic on the missing bytes.
    #[test]
    fn malformed_tsi_connect_returns_einval() -> Result<(), Box<dyn std::error::Error>> {
        let muxer = VsockMuxer::new(3)?;
        let hdr = control_header(TSI_CONNECT);
        let mut body = Vec::new();
        body.extend_from_slice(&1234u32.to_le_bytes());
        body.extend_from_slice(&16u32.to_le_bytes());
        let resp = muxer.handle_tsi_control(&hdr, &body);
        assert_eq!(response_i32(&resp), Some(-22));
        Ok(())
    }
    // SENDTO_ADDR replies with nothing on success or failure, so a truncated
    // body must simply be ignored.
    #[test]
    fn malformed_tsi_sendto_addr_is_ignored() -> Result<(), Box<dyn std::error::Error>> {
        let muxer = VsockMuxer::new(3)?;
        let hdr = control_header(TSI_SENDTO_ADDR);
        let mut body = Vec::new();
        body.extend_from_slice(&1234u32.to_le_bytes());
        body.extend_from_slice(&16u32.to_le_bytes());
        let resp = muxer.handle_tsi_control(&hdr, &body);
        assert!(resp.is_empty());
        Ok(())
    }
    /// Well-formed PROXY_CREATE body: peer_port, AF_INET (2), SOCK_STREAM (1).
    fn proxy_create_body(peer_port: u32) -> Vec<u8> {
        let mut b = Vec::new();
        b.extend_from_slice(&peer_port.to_le_bytes());
        b.extend_from_slice(&2u16.to_le_bytes());
        b.extend_from_slice(&1u16.to_le_bytes());
        b
    }
    // With a token configured, a control DGRAM without the token prefix must
    // be dropped silently and must not mutate state.
    #[test]
    fn tsi_control_rejects_missing_token_prefix() -> Result<(), Box<dyn std::error::Error>> {
        let token = [0x11u8; TSI_TOKEN_LEN];
        let muxer = VsockMuxer::with_tsi_token(3, Some(token))?;
        let hdr = control_header(TSI_PROXY_CREATE);
        let body = proxy_create_body(1234);
        let resp = muxer.handle_tsi_control(&hdr, &body);
        assert!(
            resp.is_empty(),
            "unprefixed TSI control DGRAM must be dropped, got {} pkts",
            resp.len()
        );
        assert!(
            muxer.tsi.lock().unwrap().is_empty(),
            "dropped DGRAM still mutated tsi state"
        );
        Ok(())
    }
    // An incorrect token must be treated the same as a missing one.
    #[test]
    fn tsi_control_rejects_wrong_token_prefix() -> Result<(), Box<dyn std::error::Error>> {
        let token = [0x11u8; TSI_TOKEN_LEN];
        let muxer = VsockMuxer::with_tsi_token(3, Some(token))?;
        let hdr = control_header(TSI_PROXY_CREATE);
        let mut body = vec![0x22u8; TSI_TOKEN_LEN];
        body.extend_from_slice(&proxy_create_body(1234));
        let resp = muxer.handle_tsi_control(&hdr, &body);
        assert!(resp.is_empty(), "wrong-token DGRAM must be dropped");
        assert!(muxer.tsi.lock().unwrap().is_empty());
        Ok(())
    }
    // The correct token is stripped and the remaining body processed.
    #[test]
    fn tsi_control_accepts_correct_token_prefix() -> Result<(), Box<dyn std::error::Error>> {
        let token = [0x33u8; TSI_TOKEN_LEN];
        let muxer = VsockMuxer::with_tsi_token(3, Some(token))?;
        let hdr = control_header(TSI_PROXY_CREATE);
        let mut body = token.to_vec();
        body.extend_from_slice(&proxy_create_body(1234));
        let resp = muxer.handle_tsi_control(&hdr, &body);
        assert!(resp.is_empty(), "PROXY_CREATE should not reply");
        assert_eq!(
            muxer.tsi.lock().unwrap().len(),
            1,
            "authenticated PROXY_CREATE should have inserted exactly one state"
        );
        Ok(())
    }
    // With no token configured, bare control DGRAMs are accepted.
    #[test]
    fn tsi_control_disabled_token_accepts_all() -> Result<(), Box<dyn std::error::Error>> {
        let muxer = VsockMuxer::new(3)?;
        let hdr = control_header(TSI_PROXY_CREATE);
        let body = proxy_create_body(1234);
        let resp = muxer.handle_tsi_control(&hdr, &body);
        assert!(resp.is_empty(), "PROXY_CREATE doesn't reply");
        assert_eq!(muxer.tsi.lock().unwrap().len(), 1);
        Ok(())
    }
    // Sanity-check ct_eq: equal pairs, single-bit differences at both ends,
    // and mismatched lengths (including empty slices).
    #[test]
    fn ct_eq_matches_pairs_and_short_inputs() {
        let a = [0x55u8; TSI_TOKEN_LEN];
        let b = [0x55u8; TSI_TOKEN_LEN];
        assert!(ct_eq(&a, &b));
        let mut c = a;
        c[0] ^= 1;
        assert!(!ct_eq(&a, &c));
        c = a;
        c[TSI_TOKEN_LEN - 1] ^= 1;
        assert!(!ct_eq(&a, &c));
        assert!(!ct_eq(&a[..16], &a[..]));
        assert!(!ct_eq(&[], &a[..]));
        assert!(!ct_eq(&a[..], &[]));
    }
}