1use std::net::{IpAddr, Ipv4Addr, SocketAddr};
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU16, Ordering};
12
13use bytes::Bytes;
14use smoltcp::iface::{Interface, SocketHandle, SocketSet};
15use smoltcp::socket::tcp;
16use smoltcp::wire::IpEndpoint;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpListener, TcpStream};
19use tokio::sync::mpsc;
20
21use crate::config::{PortProtocol, PublishedPort};
22use crate::shared::SharedState;
23
/// Receive-buffer size, in bytes, of every guest-facing smoltcp TCP socket (64 KiB).
const TCP_RX_BUF_SIZE: usize = 64 * 1024;
/// Transmit-buffer size, in bytes, of every guest-facing smoltcp TCP socket (64 KiB).
const TCP_TX_BUF_SIZE: usize = 64 * 1024;

/// Capacity, in chunks, of each per-connection relay channel (used for both directions).
const CHANNEL_CAPACITY: usize = 32;

/// Size of the scratch buffers used when copying between sockets and relay channels (16 KiB).
const RELAY_BUF_SIZE: usize = 16 * 1024;
37
38pub struct PortPublisher {
48 inbound_rx: mpsc::Receiver<InboundConnection>,
50 _inbound_tx: mpsc::Sender<InboundConnection>,
52 connections: Vec<InboundRelay>,
54 guest_ipv4: Ipv4Addr,
56 ephemeral_port: Arc<AtomicU16>,
58 max_inbound: usize,
60}
61
62struct InboundConnection {
64 stream: TcpStream,
66 guest_port: u16,
68}
69
/// Number of `relay_data` passes a relay whose host side has finished may spend waiting
/// for its leftover host data to fit into the guest socket before the socket is aborted.
const DEFERRED_CLOSE_LIMIT: u16 = 64;
73
74struct InboundRelay {
76 handle: SocketHandle,
77 to_host: mpsc::Sender<Bytes>,
79 from_host: mpsc::Receiver<Bytes>,
81 write_buf: Option<(Bytes, usize)>,
83 close_attempts: u16,
85}
86
87impl PortPublisher {
92 pub fn new(
94 ports: &[PublishedPort],
95 guest_ipv4: Ipv4Addr,
96 tokio_handle: &tokio::runtime::Handle,
97 ) -> Self {
98 let (inbound_tx, inbound_rx) = mpsc::channel(64);
99
100 for port in ports {
102 if port.protocol == PortProtocol::Tcp {
103 let tx = inbound_tx.clone();
104 let bind_addr = SocketAddr::new(port.host_bind, port.host_port);
105 let guest_port = port.guest_port;
106 tokio_handle.spawn(async move {
107 if let Err(e) = tcp_listener_task(bind_addr, guest_port, tx).await {
108 tracing::error!(
109 bind = %bind_addr,
110 error = %e,
111 "published port listener failed",
112 );
113 }
114 });
115 }
116 }
118
119 Self {
120 inbound_rx,
121 _inbound_tx: inbound_tx,
122 connections: Vec::new(),
123 guest_ipv4,
124 ephemeral_port: Arc::new(AtomicU16::new(49152)),
125 max_inbound: 256,
126 }
127 }
128
129 pub fn accept_inbound(
134 &mut self,
135 iface: &mut Interface,
136 sockets: &mut SocketSet<'_>,
137 shared: &Arc<SharedState>,
138 tokio_handle: &tokio::runtime::Handle,
139 ) {
140 while let Ok(conn) = self.inbound_rx.try_recv() {
141 if self.connections.len() >= self.max_inbound {
142 tracing::debug!("published port: max inbound connections reached, rejecting");
143 continue;
144 }
145 let rx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUF_SIZE]);
147 let tx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUF_SIZE]);
148 let mut socket = tcp::Socket::new(rx_buf, tx_buf);
149
150 let remote = IpEndpoint::new(IpAddr::V4(self.guest_ipv4).into(), conn.guest_port);
152 let local_port = self.alloc_ephemeral_port();
153
154 if socket.connect(iface.context(), remote, local_port).is_err() {
155 tracing::debug!(
156 guest_port = conn.guest_port,
157 "failed to connect smoltcp socket to guest",
158 );
159 continue;
160 }
161
162 let handle = sockets.add(socket);
163
164 let (to_host_tx, to_host_rx) = mpsc::channel(CHANNEL_CAPACITY);
166 let (from_host_tx, from_host_rx) = mpsc::channel(CHANNEL_CAPACITY);
167
168 let shared_clone = shared.clone();
170 tokio_handle.spawn(async move {
171 let _ =
172 inbound_relay_task(conn.stream, to_host_rx, from_host_tx, shared_clone).await;
173 });
174
175 self.connections.push(InboundRelay {
176 handle,
177 to_host: to_host_tx,
178 from_host: from_host_rx,
179 write_buf: None,
180 close_attempts: 0,
181 });
182 }
183 }
184
185 pub fn relay_data(&mut self, sockets: &mut SocketSet<'_>) {
187 let mut relay_buf = [0u8; RELAY_BUF_SIZE];
188
189 for relay in &mut self.connections {
190 let socket = sockets.get_mut::<tcp::Socket>(relay.handle);
191
192 if relay.to_host.is_closed() {
194 write_host_data(socket, relay);
195 if relay.write_buf.is_none() {
196 socket.close();
197 } else {
198 relay.close_attempts += 1;
201 if relay.close_attempts >= DEFERRED_CLOSE_LIMIT {
202 socket.abort();
203 }
204 }
205 continue;
206 }
207
208 while socket.can_recv() {
210 match socket.recv_slice(&mut relay_buf) {
211 Ok(n) if n > 0 => {
212 let data = Bytes::copy_from_slice(&relay_buf[..n]);
213 if relay.to_host.try_send(data).is_err() {
214 break;
215 }
216 }
217 _ => break,
218 }
219 }
220
221 write_host_data(socket, relay);
223 }
224 }
225
226 pub fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
231 self.connections.retain(|relay| {
232 let socket = sockets.get::<tcp::Socket>(relay.handle);
233 let closed = matches!(socket.state(), tcp::State::Closed);
234 if closed {
235 sockets.remove(relay.handle);
236 }
237 !closed
238 });
239 }
240}
241
242impl PortPublisher {
243 fn alloc_ephemeral_port(&self) -> u16 {
244 loop {
245 let port = self.ephemeral_port.fetch_add(1, Ordering::Relaxed);
246 if port == 0 || port < 49152 {
248 self.ephemeral_port.store(49152, Ordering::Relaxed);
249 continue;
250 }
251 return port;
252 }
253 }
254}
255
256async fn tcp_listener_task(
262 bind_addr: SocketAddr,
263 guest_port: u16,
264 inbound_tx: mpsc::Sender<InboundConnection>,
265) -> std::io::Result<()> {
266 let listener = TcpListener::bind(bind_addr).await?;
267 tracing::debug!(bind = %bind_addr, guest_port, "published port listener started");
268
269 loop {
270 let (stream, _peer) = listener.accept().await?;
271 let conn = InboundConnection { stream, guest_port };
272 if inbound_tx.send(conn).await.is_err() {
273 break; }
275 }
276
277 Ok(())
278}
279
280async fn inbound_relay_task(
282 stream: TcpStream,
283 mut to_host_rx: mpsc::Receiver<Bytes>,
284 from_host_tx: mpsc::Sender<Bytes>,
285 shared: Arc<SharedState>,
286) -> std::io::Result<()> {
287 let (mut rx, mut tx) = stream.into_split();
288 let mut buf = vec![0u8; RELAY_BUF_SIZE];
289
290 loop {
291 tokio::select! {
292 data = to_host_rx.recv() => {
294 match data {
295 Some(bytes) => {
296 if let Err(e) = tx.write_all(&bytes).await {
297 tracing::debug!(error = %e, "write to host client failed");
298 break;
299 }
300 }
301 None => break,
302 }
303 }
304
305 result = rx.read(&mut buf) => {
307 match result {
308 Ok(0) => break,
309 Ok(n) => {
310 let data = Bytes::copy_from_slice(&buf[..n]);
311 if from_host_tx.send(data).await.is_err() {
312 break;
313 }
314 shared.proxy_wake.wake();
315 }
316 Err(e) => {
317 tracing::debug!(error = %e, "read from host client failed");
318 break;
319 }
320 }
321 }
322 }
323 }
324
325 Ok(())
326}
327
328fn write_host_data(socket: &mut tcp::Socket<'_>, relay: &mut InboundRelay) {
330 if let Some((data, offset)) = &mut relay.write_buf {
332 if socket.can_send() {
333 match socket.send_slice(&data[*offset..]) {
334 Ok(written) => {
335 *offset += written;
336 if *offset >= data.len() {
337 relay.write_buf = None;
338 }
339 }
340 Err(_) => return,
341 }
342 } else {
343 return;
344 }
345 }
346
347 while relay.write_buf.is_none() {
349 match relay.from_host.try_recv() {
350 Ok(data) => {
351 if socket.can_send() {
352 match socket.send_slice(&data) {
353 Ok(written) if written < data.len() => {
354 relay.write_buf = Some((data, written));
355 }
356 Err(_) => {
357 relay.write_buf = Some((data, 0));
358 }
359 _ => {}
360 }
361 } else {
362 relay.write_buf = Some((data, 0));
363 }
364 }
365 Err(_) => break,
366 }
367 }
368}