use crate::queues::WakePipe;
use crate::virtio_net_log;
use smoltcp::iface::{Interface, SocketHandle, SocketSet};
use smoltcp::socket::tcp;
use smoltcp::wire::IpListenEndpoint;
use std::collections::{HashMap, HashSet};
use std::io::{self, Read, Write};
use std::net::{Ipv4Addr, Shutdown, SocketAddr, TcpStream};
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::mpsc::{self, Receiver, SyncSender, TryRecvError};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
/// Size of each smoltcp socket's receive ring buffer (64 KiB).
const TCP_RX_BUFFER_BYTES: usize = 65_536;
/// Size of each smoltcp socket's transmit ring buffer (64 KiB).
const TCP_TX_BUFFER_BYTES: usize = 65_536;
/// Default cap on simultaneously tracked relay sockets.
const MAX_CONNECTIONS: usize = 256;
/// Bound, in messages, of each per-direction relay channel.
const CHANNEL_CAPACITY: usize = 32;
/// Largest chunk moved in one step between a socket and its relay (16 KiB).
const RELAY_BUFFER_BYTES: usize = 16_384;
/// Graceful-close passes tolerated with undelivered data before the socket is aborted.
const CLOSE_RETRY_LIMIT: u16 = 64;
/// Back-off used by a relay thread when an iteration made no progress.
const PROXY_IDLE_SLEEP: Duration = Duration::from_millis(10);
/// First gateway source port handed out to published connections
/// (start of the IANA dynamic/private range).
const PUBLISHED_PORT_START: u16 = 49_152;
/// Last gateway source port handed out to published connections.
const PUBLISHED_PORT_END: u16 = u16::MAX;
/// Tracks the smoltcp TCP sockets that are relayed to host `TcpStream`s, together
/// with the channels each socket uses to exchange payloads with its relay thread.
pub struct TcpRelayTable {
/// Per-socket relay state, keyed by the smoltcp socket handle.
connections: HashMap<SocketHandle, TrackedConnection>,
/// `(source, destination)` pairs currently tracked; backs `has_socket_for`.
connection_keys: HashSet<(SocketAddr, SocketAddr)>,
/// Gateway source ports currently reserved by published connections.
used_published_ports: HashSet<u16>,
/// Next candidate for `allocate_published_port`; cycles through
/// `PUBLISHED_PORT_START..=PUBLISHED_PORT_END`.
next_published_port: u16,
/// Once `connections.len()` reaches this value, new sockets are refused.
max_connections: usize,
}
/// Everything needed to start the host relay for a connection whose handshake has
/// completed; produced by `TcpRelayTable::take_new_connections`.
pub struct NewTcpConnection {
/// Guest-side destination address of the connection.
pub destination: SocketAddr,
/// Where the relay obtains its host stream (dial vs. already-accepted stream).
pub relay_target: RelayTarget,
/// Guest -> host payloads, fed by `TcpRelayTable::relay_data`.
pub from_smoltcp: Receiver<Vec<u8>>,
/// Host -> guest payloads, written into the smoltcp socket by `TcpRelayTable::relay_data`.
pub to_smoltcp: SyncSender<Vec<u8>>,
/// Shared with the table entry; the relay reports how it finished through this.
pub exit_state: RelayExitState,
}
/// Table-side state for one relayed smoltcp socket.
#[derive(Debug)]
struct TrackedConnection {
/// Source half of the `connection_keys` entry for this socket.
source: SocketAddr,
/// Destination half of the `connection_keys` entry for this socket.
destination: SocketAddr,
/// Table end of the guest -> host channel.
to_proxy: SyncSender<Vec<u8>>,
/// Table end of the host -> guest channel.
from_proxy: Receiver<Vec<u8>>,
/// Relay-side channel ends and target; taken once by `take_new_connections`.
pending_proxy_endpoints: Option<PendingProxyEndpoints>,
/// Set once the connection has been handed out for relaying; `relay_data`
/// ignores the socket until then.
relay_spawned: bool,
/// Host -> guest payload that did not fit into the socket's transmit buffer,
/// with the offset of its first unsent byte.
buffered_proxy_data: Option<(Vec<u8>, usize)>,
/// Graceful-close passes that still had buffered data; the socket is aborted at
/// `CLOSE_RETRY_LIMIT`.
close_attempts: u16,
/// Exit mode shared with the relay thread.
exit_state: RelayExitState,
/// Published-connection source port to release in `cleanup_closed`, if any.
reserved_published_port: Option<u16>,
}
/// Relay-side ends of a connection's channels, parked until the socket is ready to
/// be relayed.
#[derive(Debug)]
struct PendingProxyEndpoints {
/// Receives guest -> host payloads.
from_smoltcp: Receiver<Vec<u8>>,
/// Sends host -> guest payloads.
to_smoltcp: SyncSender<Vec<u8>>,
/// Where the relay obtains its host stream.
relay_target: RelayTarget,
}
/// Source of the host-side stream for a relay.
#[derive(Debug)]
pub enum RelayTarget {
/// The relay thread dials this address with `TcpStream::connect`.
Connect(SocketAddr),
/// An already-connected host stream (used for published ports).
Attached(TcpStream),
}
/// Exit mode shared between a relay thread and its table entry; clones refer to the
/// same underlying value.
#[derive(Clone, Debug)]
pub struct RelayExitState {
/// `RelayExitMode` stored as its `u8` discriminant.
inner: Arc<AtomicU8>,
}
/// How a relay finished, as seen by `TcpRelayTable::relay_data`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum RelayExitMode {
/// The relay is still active.
Running = 0,
/// The relay ended normally; pending data is flushed and the socket is closed.
Graceful = 1,
/// The relay failed; the socket is aborted.
Abort = 2,
}
impl RelayExitState {
    /// Creates a new shared exit state initialised to [`RelayExitMode::Running`].
    fn new() -> Self {
        Self {
            inner: Arc::new(AtomicU8::new(RelayExitMode::Running as u8)),
        }
    }

    /// Reads the current exit mode; unknown raw values decode as `Running`.
    ///
    /// `Acquire` pairs with the `Release` in [`store`](Self::store). Once the table
    /// observes a final mode, everything the relay thread did beforehand (notably the
    /// payloads it queued for the guest) is visible. The table can then drain the
    /// channel before closing the socket.
    fn load(&self) -> RelayExitMode {
        const GRACEFUL: u8 = RelayExitMode::Graceful as u8;
        const ABORT: u8 = RelayExitMode::Abort as u8;
        match self.inner.load(Ordering::Acquire) {
            GRACEFUL => RelayExitMode::Graceful,
            ABORT => RelayExitMode::Abort,
            _ => RelayExitMode::Running,
        }
    }

    /// Publishes `mode` to every clone of this state (`Release`, see [`load`](Self::load)).
    fn store(&self, mode: RelayExitMode) {
        self.inner.store(mode as u8, Ordering::Release);
    }
}
impl TcpRelayTable {
pub fn new(max_connections: Option<usize>) -> Self {
Self {
connections: HashMap::new(),
connection_keys: HashSet::new(),
used_published_ports: HashSet::new(),
next_published_port: PUBLISHED_PORT_START,
max_connections: max_connections.unwrap_or(MAX_CONNECTIONS),
}
}
pub fn has_socket_for(&self, source: &SocketAddr, destination: &SocketAddr) -> bool {
self.connection_keys.contains(&(*source, *destination))
}
pub fn create_tcp_socket(
&mut self,
source: SocketAddr,
destination: SocketAddr,
sockets: &mut SocketSet<'_>,
) -> bool {
if self.connections.len() >= self.max_connections {
tracing::warn!("dropping TCP connection because the relay table is full");
return false;
}
let rx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUFFER_BYTES]);
let tx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUFFER_BYTES]);
let mut socket = tcp::Socket::new(rx_buffer, tx_buffer);
let std::net::IpAddr::V4(destination_ip) = destination.ip() else {
return false;
};
let listen_endpoint = IpListenEndpoint {
addr: Some(destination_ip.into()),
port: destination.port(),
};
if socket.listen(listen_endpoint).is_err() {
return false;
}
let handle = sockets.add(socket);
let (to_proxy_tx, to_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
let (from_proxy_tx, from_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
let exit_state = RelayExitState::new();
self.connection_keys.insert((source, destination));
self.connections.insert(
handle,
TrackedConnection {
source,
destination,
to_proxy: to_proxy_tx,
from_proxy: from_proxy_rx,
pending_proxy_endpoints: Some(PendingProxyEndpoints {
from_smoltcp: to_proxy_rx,
to_smoltcp: from_proxy_tx,
relay_target: RelayTarget::Connect(destination),
}),
relay_spawned: false,
buffered_proxy_data: None,
close_attempts: 0,
exit_state,
reserved_published_port: None,
},
);
true
}
pub fn create_published_socket(
&mut self,
interface: &mut Interface,
gateway_ip: Ipv4Addr,
destination: SocketAddr,
host_stream: TcpStream,
sockets: &mut SocketSet<'_>,
) -> bool {
if self.connections.len() >= self.max_connections {
tracing::warn!("dropping published TCP connection because the relay table is full");
return false;
}
let Some(local_port) = self.allocate_published_port() else {
tracing::warn!(
"dropping published TCP connection because no gateway source port is available"
);
return false;
};
let std::net::IpAddr::V4(destination_ip) = destination.ip() else {
self.used_published_ports.remove(&local_port);
return false;
};
let rx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUFFER_BYTES]);
let tx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUFFER_BYTES]);
let mut socket = tcp::Socket::new(rx_buffer, tx_buffer);
let local_endpoint = IpListenEndpoint {
addr: Some(gateway_ip.into()),
port: local_port,
};
if socket
.connect(
interface.context(),
(destination_ip, destination.port()),
local_endpoint,
)
.is_err()
{
self.used_published_ports.remove(&local_port);
return false;
}
let handle = sockets.add(socket);
let source = SocketAddr::new(std::net::IpAddr::V4(gateway_ip), local_port);
let (to_proxy_tx, to_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
let (from_proxy_tx, from_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
let exit_state = RelayExitState::new();
self.connection_keys.insert((source, destination));
self.connections.insert(
handle,
TrackedConnection {
source,
destination,
to_proxy: to_proxy_tx,
from_proxy: from_proxy_rx,
pending_proxy_endpoints: Some(PendingProxyEndpoints {
from_smoltcp: to_proxy_rx,
to_smoltcp: from_proxy_tx,
relay_target: RelayTarget::Attached(host_stream),
}),
relay_spawned: false,
buffered_proxy_data: None,
close_attempts: 0,
exit_state,
reserved_published_port: Some(local_port),
},
);
true
}
pub fn relay_data(&mut self, sockets: &mut SocketSet<'_>) {
let mut read_buffer = [0u8; RELAY_BUFFER_BYTES];
for (&handle, connection) in &mut self.connections {
if !connection.relay_spawned {
continue;
}
let socket = sockets.get_mut::<tcp::Socket>(handle);
match connection.exit_state.load() {
RelayExitMode::Abort => {
socket.abort();
continue;
}
RelayExitMode::Graceful => {
flush_proxy_data(socket, connection);
if connection.buffered_proxy_data.is_none() {
socket.close();
} else {
connection.close_attempts += 1;
if connection.close_attempts >= CLOSE_RETRY_LIMIT {
socket.abort();
}
}
continue;
}
RelayExitMode::Running => {}
}
while socket.can_recv() {
match socket.recv_slice(&mut read_buffer) {
Ok(bytes_read) if bytes_read > 0 => {
let payload = read_buffer[..bytes_read].to_vec();
if connection.to_proxy.try_send(payload).is_err() {
break;
}
}
_ => break,
}
}
flush_proxy_data(socket, connection);
}
}
pub fn take_new_connections(&mut self, sockets: &mut SocketSet<'_>) -> Vec<NewTcpConnection> {
let mut new_connections = Vec::new();
for (&handle, connection) in &mut self.connections {
if connection.relay_spawned {
continue;
}
let socket = sockets.get::<tcp::Socket>(handle);
if socket.state() == tcp::State::Established {
connection.relay_spawned = true;
if let Some(endpoints) = connection.pending_proxy_endpoints.take() {
new_connections.push(NewTcpConnection {
destination: connection.destination,
relay_target: endpoints.relay_target,
from_smoltcp: endpoints.from_smoltcp,
to_smoltcp: endpoints.to_smoltcp,
exit_state: connection.exit_state.clone(),
});
}
}
}
new_connections
}
pub fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
let keys = &mut self.connection_keys;
let published_ports = &mut self.used_published_ports;
self.connections.retain(|&handle, connection| {
let socket = sockets.get::<tcp::Socket>(handle);
if socket.state() == tcp::State::Closed {
keys.remove(&(connection.source, connection.destination));
if let Some(port) = connection.reserved_published_port {
published_ports.remove(&port);
}
sockets.remove(handle);
false
} else {
true
}
});
}
fn allocate_published_port(&mut self) -> Option<u16> {
let start = self.next_published_port;
loop {
let candidate = self.next_published_port;
self.next_published_port = if candidate == PUBLISHED_PORT_END {
PUBLISHED_PORT_START
} else {
candidate + 1
};
if self.used_published_ports.insert(candidate) {
return Some(candidate);
}
if self.next_published_port == start {
return None;
}
}
}
}
pub fn spawn_tcp_relay(
destination: SocketAddr,
relay_target: RelayTarget,
from_smoltcp: Receiver<Vec<u8>>,
to_smoltcp: SyncSender<Vec<u8>>,
relay_wake: Arc<WakePipe>,
exit_state: RelayExitState,
) {
let thread_name = format!("smolvm-tcp-{}", destination.port());
virtio_net_log!(
"virtio-net: spawning host TCP relay thread destination={} thread={}",
destination,
thread_name
);
let _ = thread::Builder::new().name(thread_name).spawn(move || {
run_tcp_relay(
destination,
relay_target,
from_smoltcp,
to_smoltcp,
relay_wake,
exit_state,
)
});
}
/// Relay-thread entry point: runs `tcp_relay_loop`, logs how it ended, and publishes
/// the outcome through `exit_state` (`Abort` on any I/O error).
fn run_tcp_relay(
    destination: SocketAddr,
    relay_target: RelayTarget,
    from_smoltcp: Receiver<Vec<u8>>,
    to_smoltcp: SyncSender<Vec<u8>>,
    relay_wake: Arc<WakePipe>,
    exit_state: RelayExitState,
) {
    virtio_net_log!(
        "virtio-net: host TCP relay thread started destination={}",
        destination
    );
    let outcome = tcp_relay_loop(
        destination,
        relay_target,
        from_smoltcp,
        to_smoltcp,
        relay_wake,
    );
    let final_mode = match outcome {
        Err(err) => {
            virtio_net_log!(
                "virtio-net: host TCP relay failed destination={} error={}",
                destination,
                err
            );
            RelayExitMode::Abort
        }
        Ok(mode) => {
            virtio_net_log!(
                "virtio-net: host TCP relay thread exited destination={} mode={:?}",
                destination,
                mode
            );
            mode
        }
    };
    exit_state.store(final_mode);
}
/// Relays bytes between a connection's channels and its host `TcpStream` until one
/// side finishes.
///
/// `RelayTarget::Connect` dials the address (blocking); `RelayTarget::Attached` reuses
/// an already-accepted stream. The stream is then switched to non-blocking mode and
/// polled. The loop sleeps for `PROXY_IDLE_SLEEP` whenever an iteration makes no
/// progress.
/// * guest -> host: payloads from `from_smoltcp` are written to the stream. A payload
///   the kernel only partly accepts is kept in `pending_write` and resumed later, so a
///   slow host peer applies back-pressure instead of failing the relay. After the
///   channel disconnects and all pending bytes are written, the write half is shut
///   down.
/// * host -> guest: bytes read from the stream are sent on `to_smoltcp` (blocking while
///   the channel is full) and `relay_wake` is signalled.
///
/// Returns `Ok(RelayExitMode::Graceful)` when the host stream reaches EOF or the
/// `to_smoltcp` receiver is gone.
///
/// # Errors
///
/// Any I/O error from connecting, configuring, reading or writing the stream, other
/// than `WouldBlock` and `Interrupted`.
fn tcp_relay_loop(
    destination: SocketAddr,
    relay_target: RelayTarget,
    from_smoltcp: Receiver<Vec<u8>>,
    to_smoltcp: SyncSender<Vec<u8>>,
    relay_wake: Arc<WakePipe>,
) -> io::Result<RelayExitMode> {
    let mut stream = match relay_target {
        RelayTarget::Connect(destination) => {
            virtio_net_log!(
                "virtio-net: connecting host TCP relay socket destination={}",
                destination
            );
            let stream = TcpStream::connect(destination)?;
            virtio_net_log!(
                "virtio-net: host TCP relay socket connected destination={}",
                destination
            );
            stream
        }
        RelayTarget::Attached(stream) => {
            virtio_net_log!(
                "virtio-net: using accepted host TCP socket for published port guest_destination={} peer_addr={:?} local_addr={:?}",
                destination,
                stream.peer_addr().ok(),
                stream.local_addr().ok()
            );
            stream
        }
    };
    stream.set_nonblocking(true)?;
    let mut guest_write_closed = false;
    // Guest payload the host socket has not fully accepted yet: (data, next offset).
    let mut pending_write: Option<(Vec<u8>, usize)> = None;
    let mut read_buffer = [0u8; RELAY_BUFFER_BYTES];
    loop {
        let mut did_work = false;
        // guest -> host
        loop {
            if let Some((data, offset)) = pending_write.as_mut() {
                let before = *offset;
                let flushed = write_nonblocking(&mut stream, data, offset)?;
                did_work |= *offset != before;
                if !flushed {
                    // Socket is full; retry on the next iteration.
                    break;
                }
                pending_write = None;
            }
            match from_smoltcp.try_recv() {
                Ok(payload) => {
                    pending_write = Some((payload, 0));
                    did_work = true;
                }
                Err(TryRecvError::Empty) => break,
                Err(TryRecvError::Disconnected) => {
                    // Only reached with nothing pending, so no guest bytes are lost.
                    if !guest_write_closed {
                        let _ = stream.shutdown(Shutdown::Write);
                        guest_write_closed = true;
                    }
                    break;
                }
            }
        }
        // host -> guest
        match stream.read(&mut read_buffer) {
            Ok(0) => return Ok(RelayExitMode::Graceful),
            Ok(bytes_read) => {
                if to_smoltcp.send(read_buffer[..bytes_read].to_vec()).is_err() {
                    return Ok(RelayExitMode::Graceful);
                }
                relay_wake.wake();
                did_work = true;
            }
            Err(err)
                if matches!(
                    err.kind(),
                    io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted
                ) => {}
            Err(err) => return Err(err),
        }
        if !did_work {
            thread::sleep(PROXY_IDLE_SLEEP);
        }
    }
}

/// Writes `data[*offset..]` to the non-blocking `stream`, advancing `offset` past every
/// byte the kernel accepts.
///
/// Returns `Ok(true)` once all of `data` has been written, or `Ok(false)` when the
/// socket would block with bytes still outstanding. A zero-length write is reported as
/// `WriteZero`.
fn write_nonblocking(stream: &mut TcpStream, data: &[u8], offset: &mut usize) -> io::Result<bool> {
    while *offset < data.len() {
        match stream.write(&data[*offset..]) {
            Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)),
            Ok(written) => *offset += written,
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(false),
            Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
            Err(err) => return Err(err),
        }
    }
    Ok(true)
}
/// Moves host -> guest payloads into the smoltcp socket's transmit buffer.
///
/// A previously buffered partial payload is finished first. Nothing new is dequeued
/// while it is outstanding. Fresh payloads are then pulled from `from_proxy` until the
/// channel is empty or one does not fit. The unsent remainder of that payload is
/// parked in `buffered_proxy_data`, together with the offset of its first unsent byte.
fn flush_proxy_data(socket: &mut tcp::Socket<'_>, connection: &mut TrackedConnection) {
    if let Some((pending, sent)) = connection.buffered_proxy_data.as_mut() {
        if !socket.can_send() {
            return;
        }
        let Ok(written) = socket.send_slice(&pending[*sent..]) else {
            return;
        };
        *sent += written;
        if *sent < pending.len() {
            return;
        }
        connection.buffered_proxy_data = None;
    }
    while let Ok(payload) = connection.from_proxy.try_recv() {
        if !socket.can_send() {
            connection.buffered_proxy_data = Some((payload, 0));
            break;
        }
        match socket.send_slice(&payload) {
            Ok(written) if written >= payload.len() => {}
            Ok(written) => {
                connection.buffered_proxy_data = Some((payload, written));
                break;
            }
            Err(_) => {
                connection.buffered_proxy_data = Some((payload, 0));
                break;
            }
        }
    }
}