use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicU16, Ordering};
use bytes::Bytes;
use smoltcp::iface::{Interface, SocketHandle, SocketSet};
use smoltcp::socket::tcp;
use smoltcp::wire::IpEndpoint;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use crate::config::{PortProtocol, PublishedPort};
use crate::shared::SharedState;
/// Size in bytes of the receive buffer allocated for each guest-facing smoltcp TCP socket.
const TCP_RX_BUF_SIZE: usize = 65536;
/// Size in bytes of the transmit buffer allocated for each guest-facing smoltcp TCP socket.
const TCP_TX_BUF_SIZE: usize = 65536;
/// Capacity (in chunks, not bytes) of each per-connection channel between the
/// smoltcp side and the host relay task.
const CHANNEL_CAPACITY: usize = 32;
/// Size in bytes of the scratch buffers used when reading from the guest socket
/// and from the host TCP stream; this bounds the size of a relayed chunk.
const RELAY_BUF_SIZE: usize = 16384;
/// Accepts host-side TCP connections on published ports and relays them to the
/// guest through smoltcp sockets.
pub struct PortPublisher {
/// Receives connections accepted by the listener tasks spawned in `new`.
inbound_rx: mpsc::Receiver<InboundConnection>,
/// Sender half stored alongside the receiver and never read in this file;
/// presumably kept so the channel stays open even when no listener task holds
/// a sender — TODO confirm.
_inbound_tx: mpsc::Sender<InboundConnection>,
/// One entry per connection currently being relayed.
connections: Vec<InboundRelay>,
/// Guest address used as the remote endpoint when connecting smoltcp sockets.
guest_ipv4: Ipv4Addr,
/// Counter from which local (source) ports for guest-facing sockets are drawn;
/// initialised to 49152.
ephemeral_port: Arc<AtomicU16>,
/// Upper bound on `connections.len()`; further inbound connections are dropped.
max_inbound: usize,
}
/// A host-side connection accepted by a listener task, waiting to be paired
/// with a smoltcp socket.
struct InboundConnection {
/// The accepted host TCP stream.
stream: TcpStream,
/// Guest port that the smoltcp socket should connect to for this connection.
guest_port: u16,
}
/// Number of `relay_data` passes a relay may spend with unflushed host data
/// after the host side has gone away before its guest socket is aborted
/// instead of closed gracefully.
const DEFERRED_CLOSE_LIMIT: u16 = 64;
/// Per-connection state held on the smoltcp side of a relayed connection.
struct InboundRelay {
/// Handle of the guest-facing TCP socket inside the `SocketSet`.
handle: SocketHandle,
/// Carries bytes read from the guest socket to the host relay task.
to_host: mpsc::Sender<Bytes>,
/// Delivers bytes read from the host client, to be written to the guest socket.
from_host: mpsc::Receiver<Bytes>,
/// A host chunk the guest socket has not fully accepted yet, together with
/// the offset of the first byte still to be sent.
write_buf: Option<(Bytes, usize)>,
/// Counts `relay_data` passes made after the host side closed while
/// `write_buf` still held data; compared against `DEFERRED_CLOSE_LIMIT`.
close_attempts: u16,
}
impl PortPublisher {
    /// Creates a publisher and spawns one host listener task per published TCP port.
    ///
    /// Entries in `ports` whose protocol is not `PortProtocol::Tcp` are skipped.
    /// Each listener runs on `tokio_handle` and forwards accepted connections
    /// over an internal channel that [`accept_inbound`](Self::accept_inbound)
    /// drains. A listener that fails is logged at error level; the failure does
    /// not affect the returned publisher.
    pub fn new(
        ports: &[PublishedPort],
        guest_ipv4: Ipv4Addr,
        tokio_handle: &tokio::runtime::Handle,
    ) -> Self {
        let (inbound_tx, inbound_rx) = mpsc::channel(64);
        for port in ports {
            if port.protocol == PortProtocol::Tcp {
                let tx = inbound_tx.clone();
                let bind_addr = SocketAddr::new(port.host_bind, port.host_port);
                let guest_port = port.guest_port;
                tokio_handle.spawn(async move {
                    if let Err(e) = tcp_listener_task(bind_addr, guest_port, tx).await {
                        tracing::error!(
                            bind = %bind_addr,
                            error = %e,
                            "published port listener failed",
                        );
                    }
                });
            }
        }
        Self {
            inbound_rx,
            _inbound_tx: inbound_tx,
            connections: Vec::new(),
            guest_ipv4,
            ephemeral_port: Arc::new(AtomicU16::new(49152)),
            max_inbound: 256,
        }
    }

    /// Drains pending host connections and opens a guest-facing socket for each.
    ///
    /// For every queued connection this creates a smoltcp TCP socket, starts a
    /// connect to `guest_ipv4:guest_port` from a freshly allocated local port,
    /// registers the socket in `sockets`, and spawns a relay task on
    /// `tokio_handle` that owns the host stream. Connections are dropped (which
    /// closes the host stream) when `max_inbound` relays already exist or when
    /// the smoltcp connect call fails.
    pub fn accept_inbound(
        &mut self,
        iface: &mut Interface,
        sockets: &mut SocketSet<'_>,
        shared: &Arc<SharedState>,
        tokio_handle: &tokio::runtime::Handle,
    ) {
        while let Ok(conn) = self.inbound_rx.try_recv() {
            if self.connections.len() >= self.max_inbound {
                tracing::debug!("published port: max inbound connections reached, rejecting");
                continue;
            }
            let rx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUF_SIZE]);
            let tx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUF_SIZE]);
            let mut socket = tcp::Socket::new(rx_buf, tx_buf);
            let remote = IpEndpoint::new(IpAddr::V4(self.guest_ipv4).into(), conn.guest_port);
            let local_port = self.alloc_ephemeral_port();
            if socket.connect(iface.context(), remote, local_port).is_err() {
                tracing::debug!(
                    guest_port = conn.guest_port,
                    "failed to connect smoltcp socket to guest",
                );
                continue;
            }
            let handle = sockets.add(socket);
            let (to_host_tx, to_host_rx) = mpsc::channel(CHANNEL_CAPACITY);
            let (from_host_tx, from_host_rx) = mpsc::channel(CHANNEL_CAPACITY);
            let shared_clone = shared.clone();
            tokio_handle.spawn(async move {
                let _ =
                    inbound_relay_task(conn.stream, to_host_rx, from_host_tx, shared_clone).await;
            });
            self.connections.push(InboundRelay {
                handle,
                to_host: to_host_tx,
                from_host: from_host_rx,
                write_buf: None,
                close_attempts: 0,
            });
        }
    }

    /// Moves data in both directions for every active relay.
    ///
    /// Guest-to-host: bytes are read from the smoltcp socket and queued on the
    /// relay's `to_host` channel. Host-to-guest: queued host chunks are written
    /// into the socket via `write_host_data`.
    ///
    /// When the host relay task has gone away (`to_host` is closed), remaining
    /// host data is flushed and the socket is closed. If the data cannot be
    /// flushed within `DEFERRED_CLOSE_LIMIT` passes, the socket is aborted.
    pub fn relay_data(&mut self, sockets: &mut SocketSet<'_>) {
        let mut relay_buf = [0u8; RELAY_BUF_SIZE];
        for relay in &mut self.connections {
            let socket = sockets.get_mut::<tcp::Socket>(relay.handle);
            if relay.to_host.is_closed() {
                write_host_data(socket, relay);
                if relay.write_buf.is_none() {
                    socket.close();
                } else {
                    // Saturate so repeated passes on a stuck relay can never overflow.
                    relay.close_attempts = relay.close_attempts.saturating_add(1);
                    if relay.close_attempts >= DEFERRED_CLOSE_LIMIT {
                        socket.abort();
                    }
                }
                continue;
            }
            while socket.can_recv() {
                // Reserve the channel slot BEFORE reading: `recv_slice` dequeues
                // the bytes from the socket, so reading first and then failing
                // to enqueue would silently drop part of the TCP stream. When
                // the channel is full the bytes simply stay in the socket's
                // receive buffer and are picked up on a later pass.
                let permit = match relay.to_host.try_reserve() {
                    Ok(permit) => permit,
                    Err(_) => break,
                };
                match socket.recv_slice(&mut relay_buf) {
                    Ok(n) if n > 0 => permit.send(Bytes::copy_from_slice(&relay_buf[..n])),
                    // Dropping the unused permit returns the slot to the channel.
                    _ => break,
                }
            }
            write_host_data(socket, relay);
        }
    }

    /// Removes relays whose guest socket has reached the `Closed` state.
    ///
    /// The socket is removed from `sockets` and the relay entry is dropped,
    /// which also drops its channel ends.
    pub fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
        self.connections.retain(|relay| {
            let state = sockets.get::<tcp::Socket>(relay.handle).state();
            if matches!(state, tcp::State::Closed) {
                sockets.remove(relay.handle);
                false
            } else {
                true
            }
        });
    }
}
impl PortPublisher {
    /// Returns the next local port from the counter, staying within 49152..=65535.
    ///
    /// The counter is advanced with a wrapping atomic add. Whenever the value
    /// drawn lies below 49152 (which is what happens after the counter wraps
    /// past 65535), the counter is reset to 49152 and another value is drawn.
    fn alloc_ephemeral_port(&self) -> u16 {
        const FIRST_DYNAMIC_PORT: u16 = 49152;
        loop {
            let candidate = self.ephemeral_port.fetch_add(1, Ordering::Relaxed);
            if candidate >= FIRST_DYNAMIC_PORT {
                return candidate;
            }
            // Wrapped around (or otherwise out of range): restart at the base.
            self.ephemeral_port
                .store(FIRST_DYNAMIC_PORT, Ordering::Relaxed);
        }
    }
}
/// Listens on `bind_addr` and forwards every accepted connection, tagged with
/// `guest_port`, over `inbound_tx`.
///
/// Returns `Ok(())` once the receiving side of `inbound_tx` has been dropped.
///
/// # Errors
///
/// Returns the I/O error if binding fails or if `accept` fails with an error
/// that is not specific to a single incoming connection.
async fn tcp_listener_task(
    bind_addr: SocketAddr,
    guest_port: u16,
    inbound_tx: mpsc::Sender<InboundConnection>,
) -> std::io::Result<()> {
    let listener = TcpListener::bind(bind_addr).await?;
    tracing::debug!(bind = %bind_addr, guest_port, "published port listener started");
    loop {
        let stream = match listener.accept().await {
            Ok((stream, _peer)) => stream,
            // These errors concern one pending connection only (e.g. the peer
            // reset it before we accepted); the listening socket is still
            // healthy, so keep serving instead of tearing the listener down.
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::ConnectionAborted
                        | std::io::ErrorKind::ConnectionReset
                        | std::io::ErrorKind::ConnectionRefused
                        | std::io::ErrorKind::Interrupted
                ) =>
            {
                tracing::debug!(
                    bind = %bind_addr,
                    error = %e,
                    "published port accept failed for one connection, continuing",
                );
                continue;
            }
            Err(e) => return Err(e),
        };
        let conn = InboundConnection { stream, guest_port };
        if inbound_tx.send(conn).await.is_err() {
            // Receiver dropped: nobody is left to take connections.
            break;
        }
    }
    Ok(())
}
/// Shuttles bytes between one host TCP `stream` and the channels connected to
/// the guest-facing socket.
///
/// * Chunks received on `to_host_rx` are written to the host stream.
/// * Bytes read from the host stream are sent on `from_host_tx`, after which
///   `shared.proxy_wake` is woken.
///
/// The loop ends when either channel is closed, the host stream reaches EOF,
/// or a read/write on the stream fails (logged at debug level). On exit both
/// channel ends are dropped and `proxy_wake` is woken once more.
///
/// Always returns `Ok(())`; stream errors end the task rather than propagate.
async fn inbound_relay_task(
    stream: TcpStream,
    mut to_host_rx: mpsc::Receiver<Bytes>,
    from_host_tx: mpsc::Sender<Bytes>,
    shared: Arc<SharedState>,
) -> std::io::Result<()> {
    let (mut rx, mut tx) = stream.into_split();
    let mut buf = vec![0u8; RELAY_BUF_SIZE];
    loop {
        tokio::select! {
            data = to_host_rx.recv() => {
                match data {
                    Some(bytes) => {
                        if let Err(e) = tx.write_all(&bytes).await {
                            tracing::debug!(error = %e, "write to host client failed");
                            break;
                        }
                    }
                    None => break,
                }
            }
            result = rx.read(&mut buf) => {
                match result {
                    Ok(0) => break,
                    Ok(n) => {
                        let data = Bytes::copy_from_slice(&buf[..n]);
                        if from_host_tx.send(data).await.is_err() {
                            break;
                        }
                        shared.proxy_wake.wake();
                    }
                    Err(e) => {
                        tracing::debug!(error = %e, "read from host client failed");
                        break;
                    }
                }
            }
        }
    }
    // Close both channel ends *before* waking, so that whoever reacts to the
    // wake already observes the channels as closed. Without this wake the
    // other side would only learn that the host connection is gone whenever
    // something unrelated happened to trigger its next pass.
    drop(to_host_rx);
    drop(from_host_tx);
    shared.proxy_wake.wake();
    Ok(())
}
/// Writes queued host data into the guest-facing `socket`.
///
/// A chunk left over from an earlier call (`relay.write_buf`) is resumed first
/// so byte order is preserved. After that, chunks are pulled from
/// `relay.from_host` until the channel has nothing more to give or a chunk
/// cannot be written completely; the unwritten chunk and the offset of its
/// first unsent byte are parked in `relay.write_buf` for the next call.
fn write_host_data(socket: &mut tcp::Socket<'_>, relay: &mut InboundRelay) {
    // Step 1: resume the parked chunk, if any.
    if let Some((pending, sent)) = relay.write_buf.take() {
        if !socket.can_send() {
            relay.write_buf = Some((pending, sent));
            return;
        }
        match socket.send_slice(&pending[sent..]) {
            Ok(accepted) => {
                let sent = sent + accepted;
                if sent < pending.len() {
                    // Still not done; the loop below is skipped because
                    // `write_buf` is occupied again.
                    relay.write_buf = Some((pending, sent));
                }
            }
            Err(_) => {
                relay.write_buf = Some((pending, sent));
                return;
            }
        }
    }

    // Step 2: feed fresh chunks until one has to be parked.
    while relay.write_buf.is_none() {
        let chunk = match relay.from_host.try_recv() {
            Ok(chunk) => chunk,
            Err(_) => break,
        };
        // `Some(offset)` means the chunk must be parked starting at `offset`.
        let resume_at = if socket.can_send() {
            match socket.send_slice(&chunk) {
                Ok(accepted) if accepted < chunk.len() => Some(accepted),
                Ok(_) => None,
                Err(_) => Some(0),
            }
        } else {
            Some(0)
        };
        if let Some(offset) = resume_at {
            relay.write_buf = Some((chunk, offset));
        }
    }
}