use std::collections::{HashMap, HashSet};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::time::Instant as StdInstant;
use smoltcp::iface::SocketHandle;
use smoltcp::iface::{Interface, SocketSet};
use smoltcp::socket::tcp;
use smoltcp::wire::IpEndpoint;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot};
use crate::darwin::smoltcp_device::{SmoltcpDevice, TcpSynInfo};
use crate::ethernet::ETH_HEADER_LEN;
use crate::nat_engine::checksum;
/// Size in bytes of each smoltcp TCP socket buffer (every socket gets one RX and one TX buffer).
const SOCKET_BUF_SIZE: usize = 256 * 1024;
/// Capacity, in messages, of the host→guest relay channel of a bridged connection.
const HOST_TO_GUEST_CHANNEL: usize = 64;
/// Capacity, in messages, of the guest→host relay channel of a bridged connection.
const GUEST_TO_HOST_CHANNEL: usize = 64;
/// First gateway-side source port handed out for inbound (host → guest) connections.
const INBOUND_EPHEMERAL_START: u16 = 61000;
/// Last port of the inbound ephemeral range; allocation wraps back to the start after it.
const INBOUND_EPHEMERAL_END: u16 = 65535;
/// Maximum number of SYNs that may wait on a host connect at once; further SYNs get a RST.
const MAX_PENDING_SYNS: usize = 256;
/// Timeout, in seconds, of the host-side connect attempt started for a gated SYN.
const SYN_GATE_CONNECT_TIMEOUT_SECS: u64 = 5;
/// Lifetime, in seconds, of a pre-connected host stream that no guest connection has claimed.
const PRE_CONNECTED_TTL_SECS: u64 = 10;
/// TCP 4-tuple identifying a guest-initiated flow; used as the key of the
/// pending-SYN and pre-connected tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct SynFlowKey {
/// Source (guest) address of the SYN.
src_ip: Ipv4Addr,
/// Source (guest) port of the SYN.
src_port: u16,
/// Destination address the guest is connecting to.
dst_ip: Ipv4Addr,
/// Destination port the guest is connecting to.
dst_port: u16,
}
/// A guest SYN that is being held back while the host-side connect is in flight.
struct PendingSyn {
/// The original SYN frame; it is injected into the device on connect success
/// or used as the template for a RST on failure.
frame: Vec<u8>,
/// Initial sequence number of the SYN, used to tell retransmits from new attempts.
syn_seq: u32,
/// Delivers the outcome of the spawned connect task (`None` means the connect failed).
result_rx: oneshot::Receiver<Option<tokio::net::TcpStream>>,
/// When the SYN was gated; used to time out connect tasks that never report back.
created: StdInstant,
}
/// A host stream that was connected before the guest's SYN was let through,
/// waiting to be claimed by the matching guest connection.
struct PreConnected {
/// The already-established host-side TCP stream.
stream: tokio::net::TcpStream,
/// Initial sequence number of the SYN this stream was opened for.
syn_seq: u32,
/// When the stream was stored; entries older than `PRE_CONNECTED_TTL_SECS` are dropped.
created: StdInstant,
}
/// State of one connection relayed between a smoltcp socket (guest side) and a
/// host relay task.
struct BridgedConn {
/// Handle of the smoltcp socket carrying the guest side of the connection.
handle: SocketHandle,
/// Peer address recorded for this connection; in this file it is only used in log messages.
remote: SocketAddr,
/// Data read from the host; an empty `Vec` marks host EOF.
host_to_guest_rx: mpsc::Receiver<Vec<u8>>,
/// Sender for guest data towards the host; set to `None` once the guest side has closed.
guest_to_host_tx: Option<mpsc::Sender<Vec<u8>>>,
/// Host EOF was received and the guest-side socket still has to be closed.
host_eof: bool,
/// The host→guest channel reported that all its senders are gone.
host_disconnected: bool,
/// Host data that the smoltcp socket has not accepted yet.
pending_send: Option<Vec<u8>>,
}
/// Bridges TCP connections between smoltcp sockets facing the guest and
/// tokio TCP streams on the host.
pub struct TcpBridge {
/// Active bridged connections keyed by their smoltcp socket handle.
connections: HashMap<SocketHandle, BridgedConn>,
/// Ports that currently have a listen socket waiting for a guest connection.
listening_ports: HashSet<u16>,
/// Handles of the sockets created by `listen` for each port; pruned when sockets are removed.
port_handles: HashMap<u16, Vec<SocketHandle>>,
/// Next gateway-side source port to use for an inbound connection.
next_ephemeral: u16,
/// SYNs waiting for their host connect to finish.
pending_syns: HashMap<SynFlowKey, PendingSyn>,
/// Connected host streams waiting for the guest handshake to complete.
pre_connected: HashMap<SynFlowKey, PreConnected>,
}
impl Default for TcpBridge {
fn default() -> Self {
Self::new()
}
}
impl TcpBridge {
/// Creates a bridge with no connections, no listeners and no gated SYNs.
///
/// Inbound source-port allocation starts at `INBOUND_EPHEMERAL_START`.
pub fn new() -> Self {
    let next_ephemeral = INBOUND_EPHEMERAL_START;
    Self {
        connections: HashMap::default(),
        listening_ports: HashSet::default(),
        port_handles: HashMap::default(),
        next_ephemeral,
        pending_syns: HashMap::default(),
        pre_connected: HashMap::default(),
    }
}
pub fn ensure_listen_sockets(
&mut self,
syn_ports: &[crate::darwin::smoltcp_device::TcpSynInfo],
sockets: &mut SocketSet<'_>,
) {
for syn in syn_ports {
let port = syn.dst_port;
if self.listening_ports.contains(&port) {
continue;
}
let rx_buf = tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF_SIZE]);
let tx_buf = tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF_SIZE]);
let mut sock = tcp::Socket::new(rx_buf, tx_buf);
if let Err(e) = sock.listen(port) {
tracing::warn!("TCP bridge: failed to listen on port {port}: {e:?}");
continue;
}
sock.set_nagle_enabled(false);
sock.set_ack_delay(None);
let handle = sockets.add(sock);
self.listening_ports.insert(port);
self.port_handles.entry(port).or_default().push(handle);
tracing::debug!("TCP bridge: listen socket created for port {port}");
}
}
/// Holds back guest SYNs and starts a host-side connect for each new flow.
///
/// * A SYN whose flow is already pending or pre-connected with the same
///   initial sequence number is treated as a retransmit and dropped.
/// * A SYN with a different sequence number replaces the stale entry.
/// * When `MAX_PENDING_SYNS` flows are already pending, the SYN is rejected.
///
/// Returns the RST frames built for SYNs rejected because of the capacity limit.
pub fn gate_syns(&mut self, syns: &[TcpSynInfo], gateway_mac: [u8; 6]) -> Vec<Vec<u8>> {
    let connect_timeout = std::time::Duration::from_secs(SYN_GATE_CONNECT_TIMEOUT_SECS);
    let mut rst_frames = Vec::new();

    for syn in syns {
        let key = SynFlowKey {
            src_ip: syn.src_ip,
            src_port: syn.src_port,
            dst_ip: syn.dst_ip,
            dst_port: syn.dst_port,
        };

        // `Some(true)`: same ISN (retransmit); `Some(false)`: new attempt on the same tuple.
        let pending_same_isn = self
            .pending_syns
            .get(&key)
            .map(|pending| pending.syn_seq == syn.syn_seq);
        match pending_same_isn {
            Some(true) => {
                tracing::debug!("TCP SYN gate: retransmit dropped for {key:?}");
                continue;
            }
            Some(false) => {
                tracing::debug!("TCP SYN gate: ISN changed for {key:?}, replacing pending");
                self.pending_syns.remove(&key);
            }
            None => {}
        }

        let pre_same_isn = self
            .pre_connected
            .get(&key)
            .map(|pre| pre.syn_seq == syn.syn_seq);
        match pre_same_isn {
            Some(true) => {
                tracing::debug!(
                    "TCP SYN gate: retransmit dropped (pre-connected exists) for {key:?}"
                );
                continue;
            }
            Some(false) => {
                tracing::debug!(
                    "TCP SYN gate: ISN changed for {key:?}, evicting stale pre-connected stream"
                );
                self.pre_connected.remove(&key);
            }
            None => {}
        }

        if self.pending_syns.len() >= MAX_PENDING_SYNS {
            tracing::warn!("TCP SYN gate: capacity limit reached, sending RST for {key:?}");
            // `extend` over the `Option` pushes the frame only when one could be built.
            rst_frames.extend(build_rst_from_syn(&syn.frame, gateway_mac));
            continue;
        }

        let dst_addr = SocketAddr::V4(SocketAddrV4::new(syn.dst_ip, syn.dst_port));
        let (result_tx, result_rx) = oneshot::channel();
        tokio::spawn(async move {
            let attempt =
                tokio::time::timeout(connect_timeout, tokio::net::TcpStream::connect(dst_addr))
                    .await;
            let stream = match attempt {
                Err(_) => {
                    tracing::debug!("TCP SYN gate: connect to {dst_addr} timed out");
                    None
                }
                Ok(Err(e)) => {
                    tracing::debug!("TCP SYN gate: connect to {dst_addr} failed: {e}");
                    None
                }
                Ok(Ok(s)) => {
                    tracing::debug!("TCP SYN gate: connected to {dst_addr}");
                    Some(s)
                }
            };
            // The receiver is gone when the pending entry was replaced; nothing to do then.
            let _ = result_tx.send(stream);
        });

        let pending = PendingSyn {
            frame: syn.frame.clone(),
            syn_seq: syn.syn_seq,
            result_rx,
            created: StdInstant::now(),
        };
        self.pending_syns.insert(key, pending);
        tracing::debug!(
            "TCP SYN gate: host connect started for {}:{} → {}:{}",
            syn.src_ip,
            syn.src_port,
            syn.dst_ip,
            syn.dst_port,
        );
    }

    rst_frames
}
pub fn poll_pending_syns(
&mut self,
device: &mut SmoltcpDevice,
sockets: &mut SocketSet<'_>,
gateway_mac: [u8; 6],
) -> Vec<Vec<u8>> {
let mut rst_frames = Vec::new();
let mut completed = Vec::new();
for (key, pending) in &mut self.pending_syns {
match pending.result_rx.try_recv() {
Ok(Some(stream)) => {
completed.push((*key, Some(stream)));
}
Ok(None) => {
completed.push((*key, None));
}
Err(oneshot::error::TryRecvError::Empty) => {
if pending.created.elapsed()
> std::time::Duration::from_secs(SYN_GATE_CONNECT_TIMEOUT_SECS + 1)
{
completed.push((*key, None));
}
}
Err(oneshot::error::TryRecvError::Closed) => {
completed.push((*key, None));
}
}
}
for (key, result) in completed {
let pending = self.pending_syns.remove(&key).unwrap();
match result {
Some(stream) => {
if !self.ensure_listen_socket_for_port(key.dst_port, sockets) {
if let Some(rst) = build_rst_from_syn(&pending.frame, gateway_mac) {
rst_frames.push(rst);
tracing::debug!("TCP SYN gate: listen failed, sending RST for {key:?}");
}
continue;
}
device.inject_rx(pending.frame);
self.pre_connected.insert(
key,
PreConnected {
stream,
syn_seq: pending.syn_seq,
created: StdInstant::now(),
},
);
tracing::debug!(
"TCP SYN gate: injected SYN + stored pre-connected stream for {key:?}"
);
}
None => {
if let Some(rst) = build_rst_from_syn(&pending.frame, gateway_mac) {
rst_frames.push(rst);
tracing::debug!("TCP SYN gate: sending RST for failed connect {key:?}");
}
}
}
}
self.pre_connected.retain(|key, pre| {
if pre.created.elapsed() > std::time::Duration::from_secs(PRE_CONNECTED_TTL_SECS) {
tracing::debug!("TCP SYN gate: pre-connected stream expired for {key:?}");
false
} else {
true
}
});
rst_frames
}
/// Ensures that a listen socket exists for `port`.
///
/// Returns `true` when the port already had a listener or a new one was
/// created, and `false` when the new socket could not be put into the listen state.
fn ensure_listen_socket_for_port(&mut self, port: u16, sockets: &mut SocketSet<'_>) -> bool {
    if self.listening_ports.contains(&port) {
        return true;
    }

    let mut listener = tcp::Socket::new(
        tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF_SIZE]),
        tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF_SIZE]),
    );
    match listener.listen(port) {
        Err(e) => {
            tracing::warn!("TCP bridge: failed to listen on port {port}: {e:?}");
            false
        }
        Ok(()) => {
            listener.set_nagle_enabled(false);
            listener.set_ack_delay(None);
            let handle = sockets.add(listener);
            self.listening_ports.insert(port);
            self.port_handles.entry(port).or_default().push(handle);
            tracing::debug!("TCP bridge: listen socket created for port {port}");
            true
        }
    }
}
/// Opens a connection from the gateway to `guest_ip:container_port` and
/// bridges it to the already-accepted host `stream`.
///
/// The gateway side uses the next port of the inbound ephemeral range. When
/// the smoltcp connect call fails, the attempt is logged and `stream` is dropped.
pub fn initiate_inbound(
    &mut self,
    container_port: u16,
    stream: tokio::net::TcpStream,
    guest_ip: Ipv4Addr,
    gateway_ip: Ipv4Addr,
    iface: &mut Interface,
    sockets: &mut SocketSet<'_>,
) {
    let eph_port = self.allocate_ephemeral();

    let mut sock = tcp::Socket::new(
        tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF_SIZE]),
        tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF_SIZE]),
    );
    sock.set_nagle_enabled(false);
    sock.set_ack_delay(None);

    let gateway_endpoint = IpEndpoint::new(gateway_ip.into(), eph_port);
    let guest_endpoint = IpEndpoint::new(guest_ip.into(), container_port);
    if let Err(e) = sock.connect(iface.context(), guest_endpoint, gateway_endpoint) {
        tracing::warn!("TCP bridge: inbound connect to guest:{container_port} failed: {e:?}");
        return;
    }
    let handle = sockets.add(sock);

    // One channel per direction between the poll loop and the host relay task.
    let (h2g_tx, host_to_guest_rx) = mpsc::channel::<Vec<u8>>(HOST_TO_GUEST_CHANNEL);
    let (guest_to_host_tx, g2h_rx) = mpsc::channel::<Vec<u8>>(GUEST_TO_HOST_CHANNEL);
    tokio::spawn(inbound_host_relay(stream, h2g_tx, g2h_rx));

    let conn = BridgedConn {
        handle,
        remote: SocketAddr::V4(SocketAddrV4::new(guest_ip, container_port)),
        host_to_guest_rx,
        guest_to_host_tx: Some(guest_to_host_tx),
        host_eof: false,
        host_disconnected: false,
        pending_send: None,
    };
    self.connections.insert(handle, conn);
    tracing::debug!(
        "TCP bridge: inbound connect initiated gw:{eph_port} → guest:{container_port}"
    );
}
/// Returns the current inbound source port and advances the cursor, wrapping
/// from `INBOUND_EPHEMERAL_END` back to `INBOUND_EPHEMERAL_START`.
fn allocate_ephemeral(&mut self) -> u16 {
    let following = match self.next_ephemeral {
        INBOUND_EPHEMERAL_END => INBOUND_EPHEMERAL_START,
        current => current + 1,
    };
    // Hand out the old value while storing the new cursor.
    std::mem::replace(&mut self.next_ephemeral, following)
}
/// Runs one bridge iteration over `sockets`: picks up listen sockets that have
/// become active, relays data for every bridged connection, then removes
/// connections whose socket has closed.
pub fn poll(&mut self, sockets: &mut SocketSet<'_>) {
self.detect_new_connections(sockets);
self.relay_all(sockets);
self.cleanup_closed(sockets);
}
fn detect_new_connections(&mut self, sockets: &mut SocketSet<'_>) {
let ports: Vec<u16> = self.listening_ports.iter().copied().collect();
let mut ports_to_replenish = Vec::new();
for port in ports {
let Some(handles) = self.port_handles.get(&port) else {
continue;
};
for &handle in handles {
if self.connections.contains_key(&handle) {
continue;
}
let sock = sockets.get_mut::<tcp::Socket>(handle);
if !sock.is_active() {
continue;
}
let Some(remote_ep) = sock.remote_endpoint() else {
continue;
};
let Some(local_ep) = sock.local_endpoint() else {
continue;
};
let remote_addr = endpoint_to_sockaddr(remote_ep);
let dest_addr = endpoint_to_sockaddr(local_ep);
tracing::debug!(
"TCP bridge: new connection detected guest:{remote_addr} → {dest_addr}"
);
let flow_key = SynFlowKey {
src_ip: match remote_ep.addr {
smoltcp::wire::IpAddress::Ipv4(v4) => v4,
},
src_port: remote_ep.port,
dst_ip: match local_ep.addr {
smoltcp::wire::IpAddress::Ipv4(v4) => v4,
},
dst_port: local_ep.port,
};
let (h2g_tx, h2g_rx) = mpsc::channel::<Vec<u8>>(HOST_TO_GUEST_CHANNEL);
let (g2h_tx, g2h_rx) = mpsc::channel::<Vec<u8>>(GUEST_TO_HOST_CHANNEL);
if let Some(pre) = self.pre_connected.remove(&flow_key) {
tracing::debug!(
"TCP bridge: using pre-connected stream for guest:{remote_addr} → {dest_addr}"
);
tokio::spawn(inbound_host_relay(pre.stream, h2g_tx, g2h_rx));
} else {
tracing::debug!(
"TCP bridge: no pre-connected stream, spawning connect for guest:{remote_addr} → {dest_addr}"
);
tokio::spawn(host_conn_task(dest_addr, h2g_tx, g2h_rx));
}
self.connections.insert(
handle,
BridgedConn {
handle,
remote: dest_addr,
host_to_guest_rx: h2g_rx,
guest_to_host_tx: Some(g2h_tx),
host_eof: false,
host_disconnected: false,
pending_send: None,
},
);
self.listening_ports.remove(&port);
ports_to_replenish.push(port);
}
}
for port in ports_to_replenish {
self.replenish_listen_socket(port, sockets);
}
}
/// Creates a new listen socket for `port` unless the port already has one.
///
/// A failure to enter the listen state is logged and leaves the port without
/// a listener.
fn replenish_listen_socket(&mut self, port: u16, sockets: &mut SocketSet<'_>) {
    if self.listening_ports.contains(&port) {
        return;
    }

    let mut listener = tcp::Socket::new(
        tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF_SIZE]),
        tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF_SIZE]),
    );
    match listener.listen(port) {
        Err(e) => {
            tracing::warn!("TCP bridge: failed to listen on port {port}: {e:?}");
        }
        Ok(()) => {
            listener.set_nagle_enabled(false);
            listener.set_ack_delay(None);
            let handle = sockets.add(listener);
            self.listening_ports.insert(port);
            self.port_handles.entry(port).or_default().push(handle);
            tracing::debug!("TCP bridge: replenished listen socket for port {port}");
        }
    }
}
fn relay_all(&mut self, sockets: &mut SocketSet<'_>) {
for conn in self.connections.values_mut() {
let sock = sockets.get_mut::<tcp::Socket>(conn.handle);
while sock.can_send() {
if let Some(pending) = conn.pending_send.take() {
match sock.send_slice(&pending) {
Ok(sent) if sent < pending.len() => {
conn.pending_send = Some(pending[sent..].to_vec());
break;
}
Ok(_) => {} Err(e) => {
tracing::debug!("TCP bridge: send error to {}: {e:?}", conn.remote);
conn.pending_send = Some(pending);
break;
}
}
}
match conn.host_to_guest_rx.try_recv() {
Ok(data) => {
if data.is_empty() {
conn.host_eof = true;
break;
}
tracing::debug!(
"TCP bridge: h2g relay {} bytes to {}",
data.len(),
conn.remote
);
match sock.send_slice(&data) {
Ok(sent) if sent < data.len() => {
conn.pending_send = Some(data[sent..].to_vec());
break;
}
Err(e) => {
tracing::debug!("TCP bridge: send error to {}: {e:?}", conn.remote);
break;
}
_ => {}
}
}
Err(mpsc::error::TryRecvError::Empty) => break,
Err(mpsc::error::TryRecvError::Disconnected) => {
conn.host_disconnected = true;
break;
}
}
}
if !conn.host_disconnected
&& !conn.host_eof
&& conn.pending_send.is_none()
&& !sock.can_send()
{
match conn.host_to_guest_rx.try_recv() {
Err(mpsc::error::TryRecvError::Disconnected) => {
conn.host_disconnected = true;
}
Ok(data) if data.is_empty() => {
conn.host_eof = true;
}
Ok(data) => {
conn.pending_send = Some(data);
}
Err(mpsc::error::TryRecvError::Empty) => {}
}
}
if conn.host_eof && conn.pending_send.is_none() && sock.may_send() {
sock.close();
conn.host_eof = false;
}
if conn.host_disconnected && !conn.host_eof {
match sock.state() {
tcp::State::SynSent | tcp::State::SynReceived => {
sock.abort();
continue;
}
_ => {
sock.close();
}
}
}
if sock.may_recv() {
let _ = sock.recv(|buf| {
if buf.is_empty() {
return (0, ());
}
tracing::debug!(
"TCP bridge: g2h relay {} bytes from {}",
buf.len(),
conn.remote
);
match conn.guest_to_host_tx.as_ref() {
Some(tx) => match tx.try_send(buf.to_vec()) {
Ok(()) => (buf.len(), ()),
Err(mpsc::error::TrySendError::Full(_)) => {
(0, ())
}
Err(mpsc::error::TrySendError::Closed(_)) => {
(buf.len(), ())
}
},
None => {
(buf.len(), ())
}
}
});
}
if conn.guest_to_host_tx.is_some() {
let guest_fin_received = matches!(
sock.state(),
tcp::State::CloseWait
| tcp::State::LastAck
| tcp::State::Closing
| tcp::State::TimeWait
| tcp::State::Closed
);
if guest_fin_received {
conn.guest_to_host_tx.take();
}
}
}
}
fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
let mut to_remove = Vec::new();
for (&handle, conn) in &self.connections {
let sock = sockets.get_mut::<tcp::Socket>(handle);
if !sock.is_open() || sock.state() == tcp::State::Closed {
to_remove.push(handle);
tracing::debug!("TCP bridge: connection to {} closed", conn.remote);
}
}
for handle in to_remove {
self.connections.remove(&handle);
self.port_handles.retain(|_, handles| {
handles.retain(|h| *h != handle);
!handles.is_empty()
});
sockets.remove(handle);
}
}
/// Returns the number of connections currently tracked by the bridge.
pub fn active_count(&self) -> usize {
self.connections.len()
}
}
/// Builds an Ethernet/IPv4/TCP `RST|ACK` frame answering the SYN in `syn_frame`.
///
/// The reply swaps the addresses and ports of the SYN, is sent from
/// `gateway_mac` to the SYN's source MAC, uses sequence number 0 and
/// acknowledges `syn_seq + 1`.
///
/// Returns `None` when the frame is too short, is not IPv4, or carries an
/// invalid IPv4 header length.
fn build_rst_from_syn(syn_frame: &[u8], gateway_mac: [u8; 6]) -> Option<Vec<u8>> {
    const IPV4_MIN_HEADER_LEN: usize = 20;
    const TCP_MIN_HEADER_LEN: usize = 20;

    let ip_start = ETH_HEADER_LEN;
    if syn_frame.len() < ip_start + IPV4_MIN_HEADER_LEN + TCP_MIN_HEADER_LEN {
        return None;
    }

    // Reject anything that is not a well-formed IPv4 header. An IHL below
    // 20 bytes would make the TCP fields below be read out of the IP header.
    let version = syn_frame[ip_start] >> 4;
    let ihl = usize::from(syn_frame[ip_start] & 0x0F) * 4;
    if version != 4 || ihl < IPV4_MIN_HEADER_LEN {
        return None;
    }
    let l4_start = ip_start + ihl;
    if l4_start + TCP_MIN_HEADER_LEN > syn_frame.len() {
        return None;
    }

    let src_mac = &syn_frame[6..12];
    let syn_src_ip = [
        syn_frame[ip_start + 12],
        syn_frame[ip_start + 13],
        syn_frame[ip_start + 14],
        syn_frame[ip_start + 15],
    ];
    let syn_dst_ip = [
        syn_frame[ip_start + 16],
        syn_frame[ip_start + 17],
        syn_frame[ip_start + 18],
        syn_frame[ip_start + 19],
    ];
    let syn_src_port = u16::from_be_bytes([syn_frame[l4_start], syn_frame[l4_start + 1]]);
    let syn_dst_port = u16::from_be_bytes([syn_frame[l4_start + 2], syn_frame[l4_start + 3]]);
    let syn_seq = u32::from_be_bytes([
        syn_frame[l4_start + 4],
        syn_frame[l4_start + 5],
        syn_frame[l4_start + 6],
        syn_frame[l4_start + 7],
    ]);

    let mut frame = vec![0u8; ETH_HEADER_LEN + 40];

    // Ethernet: back to the sender of the SYN, from the gateway, EtherType IPv4.
    frame[0..6].copy_from_slice(src_mac);
    frame[6..12].copy_from_slice(&gateway_mac);
    frame[12..14].copy_from_slice(&[0x08, 0x00]);

    // IPv4 header without options; source and destination are swapped.
    let ip = ETH_HEADER_LEN;
    frame[ip] = 0x45; // version 4, IHL 5
    frame[ip + 2..ip + 4].copy_from_slice(&40u16.to_be_bytes()); // total length
    frame[ip + 6..ip + 8].copy_from_slice(&0x4000u16.to_be_bytes()); // Don't Fragment
    frame[ip + 8] = 64; // TTL
    frame[ip + 9] = 6; // protocol: TCP
    frame[ip + 12..ip + 16].copy_from_slice(&syn_dst_ip);
    frame[ip + 16..ip + 20].copy_from_slice(&syn_src_ip);
    let ip_cksum = checksum::ipv4_header_checksum(&frame[ip..ip + 20]);
    frame[ip + 10..ip + 12].copy_from_slice(&ip_cksum.to_be_bytes());

    // TCP header without options; ports are swapped.
    let tcp_start = ip + 20;
    frame[tcp_start..tcp_start + 2].copy_from_slice(&syn_dst_port.to_be_bytes());
    frame[tcp_start + 2..tcp_start + 4].copy_from_slice(&syn_src_port.to_be_bytes());
    frame[tcp_start + 4..tcp_start + 8].copy_from_slice(&0u32.to_be_bytes()); // seq
    frame[tcp_start + 8..tcp_start + 12]
        .copy_from_slice(&(syn_seq.wrapping_add(1)).to_be_bytes()); // ack
    frame[tcp_start + 12] = 0x50; // data offset: 5 words
    frame[tcp_start + 13] = 0x14; // flags: RST | ACK
    frame[tcp_start + 14..tcp_start + 16].copy_from_slice(&0u16.to_be_bytes()); // window
    let tcp_cksum =
        checksum::tcp_checksum(syn_dst_ip, syn_src_ip, &frame[tcp_start..tcp_start + 20]);
    frame[tcp_start + 16..tcp_start + 18].copy_from_slice(&tcp_cksum.to_be_bytes());

    Some(frame)
}
/// Connects to `remote` on the host (10 second timeout) and relays data
/// between that stream and the bridge channels until both directions finish.
///
/// Host data is forwarded on `h2g_tx`; an empty chunk is sent when the host
/// reaches EOF. Chunks received on `g2h_rx` are written to the host; an empty
/// chunk shuts down the write half. If the connect fails or times out the task
/// returns without relaying anything.
async fn host_conn_task(
    remote: SocketAddr,
    h2g_tx: mpsc::Sender<Vec<u8>>,
    mut g2h_rx: mpsc::Receiver<Vec<u8>>,
) {
    let connect_started = StdInstant::now();
    tracing::debug!("TCP bridge: host_conn_task started for {remote}");

    let connect = tokio::net::TcpStream::connect(remote);
    let outcome = tokio::time::timeout(std::time::Duration::from_secs(10), connect).await;
    let stream = match outcome {
        Err(_) => {
            tracing::debug!(
                "TCP bridge: connect to {remote} timed out after {:?}",
                connect_started.elapsed()
            );
            return;
        }
        Ok(Err(e)) => {
            tracing::debug!(
                "TCP bridge: connect to {remote} failed after {:?}: {e}",
                connect_started.elapsed()
            );
            return;
        }
        Ok(Ok(s)) => s,
    };
    tracing::debug!(
        "TCP bridge: connected to {remote} in {:?}",
        connect_started.elapsed()
    );

    let (mut reader, mut writer) = stream.into_split();

    // Host → guest direction.
    let read_task = {
        let h2g_tx = h2g_tx.clone();
        tokio::spawn(async move {
            let mut buf = vec![0u8; 32768];
            loop {
                let n = match reader.read(&mut buf).await {
                    Ok(n) => n,
                    Err(e) => {
                        tracing::debug!("TCP bridge: host read error for {remote}: {e}");
                        break;
                    }
                };
                if n == 0 {
                    // An empty chunk tells the bridge that the host reached EOF.
                    let _ = h2g_tx.send(Vec::new()).await;
                    break;
                }
                tracing::debug!("TCP bridge: host read {n} bytes from {remote}");
                if h2g_tx.send(buf[..n].to_vec()).await.is_err() {
                    break;
                }
            }
        })
    };

    // Guest → host direction.
    let write_task = tokio::spawn(async move {
        loop {
            let Some(data) = g2h_rx.recv().await else {
                break;
            };
            if data.is_empty() {
                let _ = writer.shutdown().await;
                tracing::debug!("TCP bridge: host writer got guest EOF for {remote}");
                break;
            }
            tracing::debug!("TCP bridge: host write {} bytes to {remote}", data.len());
            if let Err(e) = writer.write_all(&data).await {
                tracing::debug!("TCP bridge: host write error for {remote}: {e}");
                break;
            }
        }
    });

    let _ = tokio::join!(read_task, write_task);
}
/// Relays data between an already-connected host `stream` and the bridge
/// channels until both directions finish.
///
/// Host data is forwarded on `h2g_tx`; an empty chunk is sent when the host
/// reaches EOF. Chunks received on `g2h_rx` are written to the host; an empty
/// chunk shuts down the write half.
async fn inbound_host_relay(
    stream: tokio::net::TcpStream,
    h2g_tx: mpsc::Sender<Vec<u8>>,
    mut g2h_rx: mpsc::Receiver<Vec<u8>>,
) {
    // Peer address for log messages only.
    let peer: String = match stream.peer_addr() {
        Ok(a) => a.to_string(),
        Err(_) => "unknown".into(),
    };
    tracing::debug!("TCP bridge: inbound relay started for {peer}");

    let (mut reader, mut writer) = stream.into_split();

    // Host → guest direction.
    let read_task = {
        let h2g_tx = h2g_tx.clone();
        let peer = peer.clone();
        tokio::spawn(async move {
            let mut buf = vec![0u8; 32768];
            loop {
                let n = match reader.read(&mut buf).await {
                    Ok(n) => n,
                    Err(e) => {
                        tracing::debug!("TCP bridge: inbound host read error for {peer}: {e}");
                        break;
                    }
                };
                if n == 0 {
                    tracing::debug!("TCP bridge: inbound host EOF for {peer}");
                    // An empty chunk tells the bridge that the host reached EOF.
                    let _ = h2g_tx.send(Vec::new()).await;
                    break;
                }
                tracing::debug!("TCP bridge: inbound host read {n} bytes from {peer}");
                if h2g_tx.send(buf[..n].to_vec()).await.is_err() {
                    break;
                }
            }
        })
    };

    // Guest → host direction.
    let write_task = tokio::spawn(async move {
        loop {
            let Some(data) = g2h_rx.recv().await else {
                break;
            };
            if data.is_empty() {
                tracing::debug!("TCP bridge: inbound guest EOF for {peer}");
                let _ = writer.shutdown().await;
                break;
            }
            tracing::debug!(
                "TCP bridge: inbound host write {} bytes to {peer}",
                data.len()
            );
            if let Err(e) = writer.write_all(&data).await {
                tracing::debug!("TCP bridge: inbound host write error for {peer}: {e}");
                break;
            }
        }
    });

    let _ = tokio::join!(read_task, write_task);
}
/// Converts a smoltcp IPv4 endpoint into a standard-library socket address.
fn endpoint_to_sockaddr(ep: IpEndpoint) -> SocketAddr {
    match ep.addr {
        smoltcp::wire::IpAddress::Ipv4(v4) => SocketAddrV4::new(v4, ep.port).into(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use smoltcp::iface::{Config, Interface};
use smoltcp::wire::{EthernetAddress, IpCidr};
use crate::darwin::smoltcp_device::{SmoltcpDevice, TcpSynInfo};
use crate::ethernet::ETH_HEADER_LEN;
/// Gateway IPv4 address assigned to the test interface.
const GW_IP: Ipv4Addr = Ipv4Addr::new(192, 168, 64, 1);
/// Gateway MAC address used as the test interface's hardware address.
const GW_MAC: [u8; 6] = [0x02, 0x00, 0x00, 0x00, 0x00, 0x01];
/// MAC address used as the source of frames sent by the simulated guest.
const GUEST_MAC: [u8; 6] = [0x02, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE];
/// IPv4 address of the simulated guest.
const GUEST_IP: Ipv4Addr = Ipv4Addr::new(192, 168, 64, 2);
/// Builds a smoltcp interface bound to the gateway MAC/IP (with any-IP
/// enabled and a default IPv4 route) together with an empty socket set.
fn make_iface_and_sockets(device: &mut SmoltcpDevice) -> (Interface, SocketSet<'static>) {
    let config = Config::new(EthernetAddress(GW_MAC).into());
    let mut iface = Interface::new(config, device, smoltcp::time::Instant::now());
    iface.update_ip_addrs(|addrs| {
        addrs.push(IpCidr::new(GW_IP.into(), 24)).unwrap();
    });
    iface.set_any_ip(true);
    iface.routes_mut().add_default_ipv4_route(GW_IP).unwrap();
    (iface, SocketSet::new(vec![]))
}
/// Builds an Ethernet/IPv4/TCP SYN frame from the guest (source port 12345,
/// sequence number 1000) to `dst_ip:dst_port`, with valid checksums.
fn make_syn_frame(dst_ip: Ipv4Addr, dst_port: u16) -> Vec<u8> {
    // IPv4 header, 20 bytes, no options.
    let mut ip_hdr = [0u8; 20];
    ip_hdr[0] = 0x45;
    ip_hdr[2..4].copy_from_slice(&40u16.to_be_bytes());
    ip_hdr[8] = 64;
    ip_hdr[9] = 6;
    ip_hdr[12..16].copy_from_slice(&GUEST_IP.octets());
    ip_hdr[16..20].copy_from_slice(&dst_ip.octets());
    let ip_sum = ip_checksum(&ip_hdr);
    ip_hdr[10..12].copy_from_slice(&ip_sum.to_be_bytes());

    // TCP header, 20 bytes, SYN flag only.
    let mut tcp_hdr = [0u8; 20];
    tcp_hdr[0..2].copy_from_slice(&12345u16.to_be_bytes());
    tcp_hdr[2..4].copy_from_slice(&dst_port.to_be_bytes());
    tcp_hdr[4..8].copy_from_slice(&1000u32.to_be_bytes());
    tcp_hdr[12] = 0x50;
    tcp_hdr[13] = 0x02;
    tcp_hdr[14..16].copy_from_slice(&65535u16.to_be_bytes());
    let tcp_sum = tcp_checksum(&GUEST_IP.octets(), &dst_ip.octets(), &tcp_hdr);
    tcp_hdr[16..18].copy_from_slice(&tcp_sum.to_be_bytes());

    let ip = ETH_HEADER_LEN;
    let mut frame = vec![0u8; ETH_HEADER_LEN + 40];
    frame[0..6].copy_from_slice(&GW_MAC);
    frame[6..12].copy_from_slice(&GUEST_MAC);
    frame[12..14].copy_from_slice(&[0x08, 0x00]);
    frame[ip..ip + 20].copy_from_slice(&ip_hdr);
    frame[ip + 20..ip + 40].copy_from_slice(&tcp_hdr);
    frame
}
/// Computes the IPv4 header checksum of `header`, ignoring whatever is
/// currently stored in the checksum field (bytes 10..12).
///
/// A trailing odd byte, if any, is not included in the sum.
fn ip_checksum(header: &[u8]) -> u16 {
    // Index of the 16-bit word that holds the checksum itself.
    const CHECKSUM_WORD: usize = 5;

    let mut acc: u32 = header
        .chunks_exact(2)
        .enumerate()
        .filter(|(word_idx, _)| *word_idx != CHECKSUM_WORD)
        .map(|(_, pair)| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();
    // Fold the carries back in (one's-complement addition).
    while acc > 0xFFFF {
        acc = (acc & 0xFFFF) + (acc >> 16);
    }
    !acc as u16
}
/// Computes the TCP checksum of `tcp_segment` over the IPv4 pseudo-header
/// built from `src_ip` and `dst_ip`, ignoring whatever is currently stored in
/// the checksum field (bytes 16..18 of the segment).
///
/// A trailing odd byte is summed as the high byte of a final 16-bit word.
fn tcp_checksum(src_ip: &[u8; 4], dst_ip: &[u8; 4], tcp_segment: &[u8]) -> u16 {
    // Index of the 16-bit word that holds the checksum itself.
    const CHECKSUM_WORD: usize = 8;
    let be_word = |hi: u8, lo: u8| u32::from(u16::from_be_bytes([hi, lo]));

    // Pseudo-header: addresses, protocol number (6 = TCP) and segment length.
    let pseudo_header = be_word(src_ip[0], src_ip[1])
        + be_word(src_ip[2], src_ip[3])
        + be_word(dst_ip[0], dst_ip[1])
        + be_word(dst_ip[2], dst_ip[3])
        + 6u32
        + tcp_segment.len() as u32;

    let words = tcp_segment.chunks_exact(2);
    let trailing = words.remainder().first().map_or(0, |&b| u32::from(b) << 8);
    let body: u32 = words
        .enumerate()
        .filter(|(word_idx, _)| *word_idx != CHECKSUM_WORD)
        .map(|(_, pair)| be_word(pair[0], pair[1]))
        .sum();

    let mut acc = pseudo_header + body + trailing;
    // Fold the carries back in (one's-complement addition).
    while acc > 0xFFFF {
        acc = (acc & 0xFFFF) + (acc >> 16);
    }
    !acc as u16
}
/// Listen sockets are created for each distinct destination port.
#[test]
fn ensure_listen_sockets_creates_on_demand() {
    let mut device = SmoltcpDevice::new(0, GW_IP);
    let (_iface, mut sockets) = make_iface_and_sockets(&mut device);
    let mut bridge = TcpBridge::new();

    let syns: Vec<TcpSynInfo> = [(443u16, 1000u16), (80, 1001)]
        .iter()
        .map(|&(dst_port, src_port)| TcpSynInfo {
            dst_port,
            src_ip: GUEST_IP,
            src_port,
            dst_ip: Ipv4Addr::new(1, 1, 1, 1),
            syn_seq: 0,
            frame: vec![],
        })
        .collect();
    bridge.ensure_listen_sockets(&syns, &mut sockets);

    for port in [443u16, 80] {
        assert!(bridge.listening_ports.contains(&port));
    }
    for port in [443u16, 80] {
        assert!(bridge.port_handles.contains_key(&port));
    }
}
/// Asking twice for the same port must not create a second listen socket.
#[test]
fn ensure_listen_sockets_deduplicates() {
    let mut device = SmoltcpDevice::new(0, GW_IP);
    let (_iface, mut sockets) = make_iface_and_sockets(&mut device);
    let mut bridge = TcpBridge::new();

    let https_syn = TcpSynInfo {
        dst_port: 443,
        src_ip: GUEST_IP,
        src_port: 1000,
        dst_ip: Ipv4Addr::new(1, 1, 1, 1),
        syn_seq: 0,
        frame: vec![],
    };
    let syns = vec![https_syn];
    for _ in 0..2 {
        bridge.ensure_listen_sockets(&syns, &mut sockets);
    }

    assert_eq!(bridge.port_handles[&443].len(), 1);
}
/// After an ARP exchange, a guest SYN to a port with a listen socket makes
/// that socket active.
#[test]
fn smoltcp_accepts_syn_with_listen_socket() {
    let mut device = SmoltcpDevice::new(0, GW_IP);
    let (mut iface, mut sockets) = make_iface_and_sockets(&mut device);
    let mut bridge = TcpBridge::new();
    let syns = vec![TcpSynInfo {
        dst_port: 443,
        src_ip: GUEST_IP,
        src_port: 1000,
        dst_ip: Ipv4Addr::new(1, 1, 1, 1),
        syn_seq: 0,
        frame: vec![],
    }];
    bridge.ensure_listen_sockets(&syns, &mut sockets);

    // Broadcast ARP request from the guest asking for the gateway's address.
    let mut arp: Vec<u8> = Vec::with_capacity(42);
    arp.extend_from_slice(&[0xFF; 6]); // Ethernet destination: broadcast
    arp.extend_from_slice(&GUEST_MAC); // Ethernet source
    arp.extend_from_slice(&[0x08, 0x06]); // EtherType: ARP
    arp.extend_from_slice(&[0x00, 0x01]); // hardware type: Ethernet
    arp.extend_from_slice(&[0x08, 0x00]); // protocol type: IPv4
    arp.extend_from_slice(&[6, 4]); // hardware / protocol address lengths
    arp.extend_from_slice(&[0x00, 0x01]); // operation: request
    arp.extend_from_slice(&GUEST_MAC); // sender hardware address
    arp.extend_from_slice(&GUEST_IP.octets()); // sender protocol address
    arp.extend_from_slice(&[0x00; 6]); // target hardware address
    arp.extend_from_slice(&GW_IP.octets()); // target protocol address
    device.inject_rx(arp);
    iface.poll(smoltcp::time::Instant::now(), &mut device, &mut sockets);
    let _ = device.take_tx_pending();

    device.inject_rx(make_syn_frame(Ipv4Addr::new(1, 1, 1, 1), 443));
    iface.poll(smoltcp::time::Instant::now(), &mut device, &mut sockets);

    let handle = bridge.port_handles[&443][0];
    let sock = sockets.get_mut::<tcp::Socket>(handle);
    assert!(
        sock.is_active(),
        "Socket should be active after SYN; state={:?}",
        sock.state()
    );
}
/// A freshly created bridge tracks no connections.
#[test]
fn bridge_active_count_starts_at_zero() {
    assert_eq!(TcpBridge::new().active_count(), 0);
}
/// Pending host data must stay queued while the socket cannot send.
#[test]
fn partial_send_preserves_remainder() {
    let mut device = SmoltcpDevice::new(0, GW_IP);
    let (_iface, mut sockets) = make_iface_and_sockets(&mut device);

    let mut sock = tcp::Socket::new(
        tcp::SocketBuffer::new(vec![0u8; 64]),
        tcp::SocketBuffer::new(vec![0u8; 16]),
    );
    sock.set_nagle_enabled(false);
    let handle = sockets.add(sock);

    // The unused channel ends are kept alive so the channels stay open.
    let (_h2g_tx, h2g_rx) = mpsc::channel::<Vec<u8>>(4);
    let (g2h_tx, _g2h_rx) = mpsc::channel::<Vec<u8>>(4);
    let remote: SocketAddr = "1.1.1.1:443".parse().unwrap();
    let conn = BridgedConn {
        handle,
        remote,
        host_to_guest_rx: h2g_rx,
        guest_to_host_tx: Some(g2h_tx),
        host_eof: false,
        host_disconnected: false,
        pending_send: Some(vec![0xAA; 32]),
    };
    let mut bridge = TcpBridge::new();
    bridge.connections.insert(handle, conn);

    bridge.relay_all(&mut sockets);

    let pending = &bridge.connections.get(&handle).unwrap().pending_send;
    assert!(
        pending.is_some(),
        "Pending data should be preserved when socket can't send"
    );
    assert_eq!(pending.as_ref().unwrap().len(), 32);
}
/// Host EOF must not be acted upon while host data is still pending.
#[test]
fn host_eof_waits_for_pending_send() {
    let mut device = SmoltcpDevice::new(0, GW_IP);
    let (_iface, mut sockets) = make_iface_and_sockets(&mut device);

    let sock = tcp::Socket::new(
        tcp::SocketBuffer::new(vec![0u8; 64]),
        tcp::SocketBuffer::new(vec![0u8; 64]),
    );
    let handle = sockets.add(sock);

    // The unused channel ends are kept alive so the channels stay open.
    let (_h2g_tx, h2g_rx) = mpsc::channel::<Vec<u8>>(4);
    let (g2h_tx, _g2h_rx) = mpsc::channel::<Vec<u8>>(4);
    let remote: SocketAddr = "1.1.1.1:443".parse().unwrap();
    let conn = BridgedConn {
        handle,
        remote,
        host_to_guest_rx: h2g_rx,
        guest_to_host_tx: Some(g2h_tx),
        host_eof: true,
        host_disconnected: false,
        pending_send: Some(vec![0xBB; 10]),
    };
    let mut bridge = TcpBridge::new();
    bridge.connections.insert(handle, conn);

    bridge.relay_all(&mut sockets);

    let still_eof = bridge.connections.get(&handle).unwrap().host_eof;
    assert!(
        still_eof,
        "host_eof should remain set while pending_send is non-empty"
    );
}
/// Removing a closed socket must also drop its handle from `port_handles`.
#[test]
fn cleanup_removes_stale_port_handles() {
    let mut device = SmoltcpDevice::new(0, GW_IP);
    let (_iface, mut sockets) = make_iface_and_sockets(&mut device);
    let mut bridge = TcpBridge::new();

    let syns = vec![TcpSynInfo {
        dst_port: 443,
        src_ip: GUEST_IP,
        src_port: 1000,
        dst_ip: Ipv4Addr::new(1, 1, 1, 1),
        syn_seq: 0,
        frame: vec![],
    }];
    bridge.ensure_listen_sockets(&syns, &mut sockets);
    assert!(bridge.port_handles.contains_key(&443));

    // Close the socket and register it as a connection so cleanup considers it.
    let handle = bridge.port_handles[&443][0];
    sockets.get_mut::<tcp::Socket>(handle).abort();
    let (_h2g_tx, h2g_rx) = mpsc::channel::<Vec<u8>>(1);
    let (g2h_tx, _g2h_rx) = mpsc::channel::<Vec<u8>>(1);
    let conn = BridgedConn {
        handle,
        remote: "1.1.1.1:443".parse().unwrap(),
        host_to_guest_rx: h2g_rx,
        guest_to_host_tx: Some(g2h_tx),
        host_eof: false,
        host_disconnected: false,
        pending_send: None,
    };
    bridge.connections.insert(handle, conn);

    bridge.cleanup_closed(&mut sockets);

    assert!(
        !bridge.port_handles.contains_key(&443),
        "port_handles should be cleaned up after socket removal"
    );
    assert!(bridge.connections.is_empty());
}
/// When the guest side is closed, the guest→host sender must be dropped.
#[test]
fn guest_eof_drops_sender() {
    let mut device = SmoltcpDevice::new(0, GW_IP);
    let (_iface, mut sockets) = make_iface_and_sockets(&mut device);

    let sock = tcp::Socket::new(
        tcp::SocketBuffer::new(vec![0u8; 64]),
        tcp::SocketBuffer::new(vec![0u8; 64]),
    );
    let handle = sockets.add(sock);

    let (_h2g_tx, h2g_rx) = mpsc::channel::<Vec<u8>>(4);
    let (g2h_tx, mut g2h_rx) = mpsc::channel::<Vec<u8>>(4);
    let remote: SocketAddr = "1.1.1.1:443".parse().unwrap();
    let conn = BridgedConn {
        handle,
        remote,
        host_to_guest_rx: h2g_rx,
        guest_to_host_tx: Some(g2h_tx),
        host_eof: false,
        host_disconnected: false,
        pending_send: None,
    };
    let mut bridge = TcpBridge::new();
    bridge.connections.insert(handle, conn);

    bridge.relay_all(&mut sockets);

    let sender_gone = bridge
        .connections
        .get(&handle)
        .unwrap()
        .guest_to_host_tx
        .is_none();
    assert!(sender_gone, "guest_to_host_tx should be taken after EOF");

    drop(bridge);
    assert!(
        g2h_rx.try_recv().is_err(),
        "Receiver should see disconnect after sender replaced and dropped"
    );
}
/// An inbound connection registers one bridged connection with an open
/// socket towards the guest.
#[tokio::test]
async fn initiate_inbound_creates_connecting_socket() {
    let mut device = SmoltcpDevice::new(0, GW_IP);
    let (mut iface, mut sockets) = make_iface_and_sockets(&mut device);
    let mut bridge = TcpBridge::new();

    // Produce a real connected host stream via a loopback listener.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let listen_addr = listener.local_addr().unwrap();
    let (connected, _accepted) = tokio::join!(
        tokio::net::TcpStream::connect(listen_addr),
        listener.accept()
    );
    let host_stream = connected.unwrap();

    bridge.initiate_inbound(80, host_stream, GUEST_IP, GW_IP, &mut iface, &mut sockets);

    assert_eq!(
        bridge.active_count(),
        1,
        "Should have one inbound connection"
    );
    let (handle, conn) = bridge.connections.iter().next().unwrap();
    let sock = sockets.get_mut::<tcp::Socket>(*handle);
    assert!(
        sock.is_open(),
        "Socket should be open after connect; state={:?}",
        sock.state()
    );
    let expected_remote = SocketAddr::V4(SocketAddrV4::new(GUEST_IP, 80));
    assert_eq!(conn.remote, expected_remote);
}
#[tokio::test]
async fn syn_gate_connect_success_injects_syn() {
let mut device = SmoltcpDevice::new(0, GW_IP);
let (_iface, mut sockets) = make_iface_and_sockets(&mut device);
let mut bridge = TcpBridge::new();
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let syn = make_syn_frame(addr.ip().to_string().parse().unwrap(), addr.port());
let syn_info = TcpSynInfo {
dst_port: addr.port(),
src_ip: GUEST_IP,
src_port: 12345,
dst_ip: addr.ip().to_string().parse().unwrap(),
syn_seq: 1000,
frame: syn.clone(),
};
bridge.gate_syns(&[syn_info], GW_MAC);
assert_eq!(bridge.pending_syns.len(), 1);
let _accepted = listener.accept().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let rst_frames = bridge.poll_pending_syns(&mut device, &mut sockets, GW_MAC);
assert!(
rst_frames.is_empty(),
"No RST should be generated on success"
);
assert!(bridge.pending_syns.is_empty(), "Pending should be consumed");
assert_eq!(
bridge.pre_connected.len(),
1,
"Should have pre-connected stream"
);
assert_eq!(device.take_tx_pending().len(), 0); assert!(bridge.listening_ports.contains(&addr.port()));
}
/// A gated SYN whose host connect fails produces exactly one RST|ACK frame
/// addressed back to the guest.
#[tokio::test]
async fn syn_gate_connect_failure_sends_rst() {
    let mut device = SmoltcpDevice::new(0, GW_IP);
    let (_iface, mut sockets) = make_iface_and_sockets(&mut device);
    let mut bridge = TcpBridge::new();

    let syn_info = TcpSynInfo {
        dst_port: 1,
        src_ip: GUEST_IP,
        src_port: 12345,
        dst_ip: Ipv4Addr::LOCALHOST,
        syn_seq: 1000,
        frame: make_syn_frame(Ipv4Addr::LOCALHOST, 1),
    };
    bridge.gate_syns(&[syn_info], GW_MAC);
    tokio::time::sleep(std::time::Duration::from_millis(200)).await;

    let rst_frames = bridge.poll_pending_syns(&mut device, &mut sockets, GW_MAC);
    assert_eq!(rst_frames.len(), 1, "Should generate exactly one RST");
    assert!(bridge.pending_syns.is_empty());
    assert!(bridge.pre_connected.is_empty());

    let rst = &rst_frames[0];
    assert!(rst.len() >= ETH_HEADER_LEN + 40);
    let tcp_start = ETH_HEADER_LEN + 20;
    assert_eq!(rst[tcp_start + 13], 0x14, "Flags should be RST|ACK");

    let mut ack_bytes = [0u8; 4];
    ack_bytes.copy_from_slice(&rst[tcp_start + 8..tcp_start + 12]);
    let ack = u32::from_be_bytes(ack_bytes);
    assert_eq!(ack, 1001, "ACK should be syn_seq + 1");

    assert_eq!(&rst[0..6], &GUEST_MAC);
    assert_eq!(&rst[6..12], &GW_MAC);
}
// Gating the same flow repeatedly must keep a single pending entry: a same-ISN
// retransmit leaves the count at one, and a SYN with a new ISN leaves the count
// at one while the stored syn_seq becomes the new value.
#[tokio::test]
async fn syn_gate_retransmit_dedup() {
    let mut bridge = TcpBridge::new();
    let remote = Ipv4Addr::new(1, 1, 1, 1);
    let frame = make_syn_frame(remote, 443);

    // All SYNs in this test share the flow tuple; only ISN and frame buffer vary.
    let syn_with_isn = |isn, frame| TcpSynInfo {
        dst_port: 443,
        src_ip: GUEST_IP,
        src_port: 12345,
        dst_ip: remote,
        syn_seq: isn,
        frame,
    };

    // First sighting creates the pending entry.
    let original = syn_with_isn(1000, frame.clone());
    bridge.gate_syns(std::slice::from_ref(&original), GW_MAC);
    assert_eq!(bridge.pending_syns.len(), 1);

    // Identical retransmit.
    bridge.gate_syns(&[original], GW_MAC);
    assert_eq!(bridge.pending_syns.len(), 1);

    // Same flow, different ISN.
    let restarted = syn_with_isn(5000, frame);
    bridge.gate_syns(&[restarted], GW_MAC);
    assert_eq!(bridge.pending_syns.len(), 1);

    let flow = SynFlowKey {
        src_ip: GUEST_IP,
        src_port: 12345,
        dst_ip: remote,
        dst_port: 443,
    };
    assert_eq!(bridge.pending_syns[&flow].syn_seq, 5000);
}
// Seeds a pre-connected entry whose creation time is already older than
// PRE_CONNECTED_TTL_SECS and checks that one poll removes it without
// producing any RST frame.
#[tokio::test]
async fn pre_connected_expires_after_ttl() {
    let mut device = SmoltcpDevice::new(0, GW_IP);
    let (_iface, mut sockets) = make_iface_and_sockets(&mut device);
    let mut bridge = TcpBridge::new();
    let flow = SynFlowKey {
        src_ip: GUEST_IP,
        src_port: 12345,
        dst_ip: Ipv4Addr::new(1, 1, 1, 1),
        dst_port: 443,
    };

    // A genuine connected loopback stream; the accepted side stays bound so the
    // peer is not dropped mid-test.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let local = listener.local_addr().unwrap();
    let (client, _accepted) =
        tokio::join!(tokio::net::TcpStream::connect(local), listener.accept());
    let client = client.unwrap();

    // Back-date the entry one second past the TTL.
    let overdue = std::time::Duration::from_secs(PRE_CONNECTED_TTL_SECS + 1);
    let stale_since = StdInstant::now() - overdue;
    bridge.pre_connected.insert(
        flow,
        PreConnected {
            stream: client,
            syn_seq: 1000,
            created: stale_since,
        },
    );
    assert_eq!(bridge.pre_connected.len(), 1);

    let resets = bridge.poll_pending_syns(&mut device, &mut sockets, GW_MAC);
    assert!(resets.is_empty());
    assert!(
        bridge.pre_connected.is_empty(),
        "Expired entry should be removed"
    );
}
// With a fresh pre-connected stream already stored for a flow, a retransmitted
// SYN carrying the same ISN must not create a pending entry and must leave the
// stored stream in place.
#[tokio::test]
async fn pre_connected_same_isn_retransmit_dedup() {
    let mut bridge = TcpBridge::new();
    let remote = Ipv4Addr::new(1, 1, 1, 1);
    let flow = SynFlowKey {
        src_ip: GUEST_IP,
        src_port: 12345,
        dst_ip: remote,
        dst_port: 443,
    };

    // A genuine connected loopback stream; the accepted side stays bound so the
    // peer is not dropped mid-test.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let local = listener.local_addr().unwrap();
    let (client, _accepted) =
        tokio::join!(tokio::net::TcpStream::connect(local), listener.accept());
    let client = client.unwrap();

    bridge.pre_connected.insert(
        flow,
        PreConnected {
            stream: client,
            syn_seq: 1000,
            created: StdInstant::now(),
        },
    );

    // Retransmit: same flow tuple and the same ISN as the stored entry.
    let retransmit = TcpSynInfo {
        dst_port: 443,
        src_ip: GUEST_IP,
        src_port: 12345,
        dst_ip: remote,
        syn_seq: 1000,
        frame: make_syn_frame(remote, 443),
    };
    bridge.gate_syns(&[retransmit], GW_MAC);

    assert!(
        bridge.pending_syns.is_empty(),
        "Same ISN retransmit should not create a new pending entry"
    );
    assert_eq!(
        bridge.pre_connected.len(),
        1,
        "Pre-connected stream should be preserved"
    );
}
// With a pre-connected stream stored for a flow, a SYN on the same flow with a
// different ISN must evict that stream and register a new pending entry that
// records the new ISN.
#[tokio::test]
async fn pre_connected_different_isn_evicts_stale_stream() {
    let mut bridge = TcpBridge::new();
    let remote = Ipv4Addr::new(1, 1, 1, 1);
    let flow = SynFlowKey {
        src_ip: GUEST_IP,
        src_port: 12345,
        dst_ip: remote,
        dst_port: 443,
    };

    // A genuine connected loopback stream; the accepted side stays bound so the
    // peer is not dropped mid-test.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let local = listener.local_addr().unwrap();
    let (client, _accepted) =
        tokio::join!(tokio::net::TcpStream::connect(local), listener.accept());
    let client = client.unwrap();

    bridge.pre_connected.insert(
        flow,
        PreConnected {
            stream: client,
            syn_seq: 1000,
            created: StdInstant::now(),
        },
    );

    // Same flow tuple, but the ISN differs from the stored entry's 1000.
    let fresh_attempt = TcpSynInfo {
        dst_port: 443,
        src_ip: GUEST_IP,
        src_port: 12345,
        dst_ip: remote,
        syn_seq: 5000,
        frame: make_syn_frame(remote, 443),
    };
    bridge.gate_syns(&[fresh_attempt], GW_MAC);

    assert!(
        bridge.pre_connected.is_empty(),
        "Stale pre-connected stream should be evicted"
    );
    assert_eq!(
        bridge.pending_syns.len(),
        1,
        "New ISN should create a new pending entry"
    );
    let replacement = &bridge.pending_syns[&flow];
    assert_eq!(replacement.syn_seq, 5000);
}
}