1use std::{
8 collections::HashMap,
9 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
10 sync::Arc,
11 time::{Duration, Instant},
12};
13
14use tokio::{
15 net::{TcpListener, TcpStream, UdpSocket},
16 sync::Mutex,
17};
18
19use crate::config::{PortProtocol, PublishedPort};
20
/// Idle timeout used by the UDP relay to expire per-peer sessions.
const UDP_SESSION_TIMEOUT: Duration = Duration::from_secs(60);

/// Upper bound on the number of concurrent peer sessions tracked by one UDP
/// relay loop; datagrams from new peers are dropped once it is reached.
const MAX_UDP_SESSIONS: usize = 1024;
30
/// Handle to the relay tasks spawned for a set of published ports.
///
/// NOTE(review): dropping a tokio `JoinHandle` detaches its task instead of
/// cancelling it, and no `Drop` impl is visible in this file, so dropping a
/// `PortPublisher` does not appear to stop the relays — confirm this is the
/// intended lifetime.
pub struct PortPublisher {
    /// Join handles of the spawned per-port tasks; stored but not read here.
    _handles: Vec<tokio::task::JoinHandle<()>>,
}
40
/// The guest's addresses, one optional slot per IP family.
///
/// Used to pick the forwarding target for a published port based on the
/// address family of the host bind address. The type is a pair of small
/// `Copy` values, so it derives the usual value-type traits to make it easy
/// to log, compare and pass around.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GuestAddresses {
    /// Guest IPv4 address, if the guest has one.
    pub ipv4: Option<Ipv4Addr>,
    /// Guest IPv6 address, if the guest has one.
    pub ipv6: Option<Ipv6Addr>,
}
46
/// Per-peer state kept by the UDP relay, keyed by the peer's socket address.
struct UdpSession {
    /// Socket dedicated to this peer for exchanging datagrams with the guest.
    guest_socket: Arc<UdpSocket>,

    /// When a datagram was last relayed for this session; refreshed for
    /// traffic in both directions.
    last_active: Instant,
}
55
/// A host-side socket that has been bound but whose relay task has not been
/// spawned yet.
pub enum PendingListener {
    /// A bound TCP listener and the guest address its connections go to.
    Tcp {
        /// Host-side listener accepting client connections.
        listener: TcpListener,
        /// Guest address each accepted connection is proxied to.
        guest_addr: SocketAddr,
    },
    /// A bound UDP socket and the guest address its datagrams go to.
    Udp {
        /// Host-side socket receiving client datagrams.
        socket: UdpSocket,
        /// Guest address received datagrams are relayed to.
        guest_addr: SocketAddr,
    },
}
73
74impl PortPublisher {
79 pub async fn bind(
85 ports: &[PublishedPort],
86 guest_ipv4: Option<Ipv4Addr>,
87 guest_ipv6: Option<Ipv6Addr>,
88 ) -> std::io::Result<Vec<PendingListener>> {
89 let mut pending = Vec::with_capacity(ports.len());
90 let guest_addresses = GuestAddresses {
91 ipv4: guest_ipv4,
92 ipv6: guest_ipv6,
93 };
94
95 for port in ports {
96 let guest_ip = resolve_guest_ip(port.host_bind, &guest_addresses)?;
97 let host_bind = SocketAddr::new(port.host_bind, port.host_port);
98 let guest_addr = SocketAddr::new(guest_ip, port.guest_port);
99
100 match port.protocol {
101 PortProtocol::Tcp => {
102 let listener = TcpListener::bind(host_bind).await?;
103 tracing::info!(%host_bind, %guest_addr, "published TCP port");
104 pending.push(PendingListener::Tcp {
105 listener,
106 guest_addr,
107 });
108 }
109 PortProtocol::Udp => {
110 let socket = UdpSocket::bind(host_bind).await?;
111 tracing::info!(%host_bind, %guest_addr, "published UDP port");
112 pending.push(PendingListener::Udp { socket, guest_addr });
113 }
114 }
115 }
116
117 Ok(pending)
118 }
119
120 pub fn start_from(pending: Vec<PendingListener>) -> Self {
126 let mut handles = Vec::with_capacity(pending.len());
127
128 for listener in pending {
129 let handle = match listener {
130 PendingListener::Tcp {
131 listener,
132 guest_addr,
133 } => tokio::spawn(async move {
134 tcp_listener_loop(listener, guest_addr).await;
135 }),
136 PendingListener::Udp { socket, guest_addr } => tokio::spawn(async move {
137 udp_relay_loop(socket, guest_addr).await;
138 }),
139 };
140 handles.push(handle);
141 }
142
143 Self { _handles: handles }
144 }
145
146 pub async fn start(
148 ports: &[PublishedPort],
149 guest_ipv4: Option<Ipv4Addr>,
150 guest_ipv6: Option<Ipv6Addr>,
151 ) -> std::io::Result<Self> {
152 let pending = Self::bind(ports, guest_ipv4, guest_ipv6).await?;
153 Ok(Self::start_from(pending))
154 }
155}
156
157async fn tcp_listener_loop(listener: TcpListener, guest_addr: SocketAddr) {
163 loop {
164 let (client_stream, peer_addr) = match listener.accept().await {
165 Ok(conn) => conn,
166 Err(e) => {
167 tracing::warn!("TCP accept error: {e}");
168 continue;
169 }
170 };
171
172 tracing::debug!(%peer_addr, %guest_addr, "TCP connection accepted");
173
174 tokio::spawn(async move {
175 if let Err(e) = tcp_proxy(client_stream, guest_addr).await {
176 tracing::debug!(%peer_addr, %guest_addr, "TCP proxy ended: {e}");
177 }
178 });
179 }
180}
181
182async fn tcp_proxy(mut client: TcpStream, guest_addr: SocketAddr) -> std::io::Result<()> {
184 let mut guest = TcpStream::connect(guest_addr).await?;
185 tokio::io::copy_bidirectional(&mut client, &mut guest).await?;
186 Ok(())
187}
188
189async fn udp_relay_loop(host_socket: UdpSocket, guest_addr: SocketAddr) {
195 let host_socket = Arc::new(host_socket);
196 let sessions: Arc<Mutex<HashMap<SocketAddr, UdpSession>>> =
197 Arc::new(Mutex::new(HashMap::new()));
198
199 let mut buf = [0u8; 65535];
200
201 loop {
202 let (n, peer_addr) = match host_socket.recv_from(&mut buf).await {
203 Ok(result) => result,
204 Err(e) => {
205 tracing::warn!("UDP recv error: {e}");
206 continue;
207 }
208 };
209
210 let data = &buf[..n];
211
212 let guest_socket = {
215 let mut map = sessions.lock().await;
216
217 if let Some(session) = map.get_mut(&peer_addr) {
218 session.last_active = Instant::now();
219 Some(Arc::clone(&session.guest_socket))
220 } else {
221 None
222 }
223 };
224
225 if let Some(socket) = guest_socket {
226 if let Err(e) = socket.send_to(data, guest_addr).await {
227 tracing::debug!(%peer_addr, "UDP send to guest failed: {e}");
228 }
229 continue;
230 }
231
232 if sessions.lock().await.len() >= MAX_UDP_SESSIONS {
234 tracing::warn!(%peer_addr, "UDP session limit reached, dropping datagram");
235 continue;
236 }
237
238 let bind_addr = if guest_addr.is_ipv6() {
240 "[::]:0"
241 } else {
242 "0.0.0.0:0"
243 };
244 let guest_socket = match UdpSocket::bind(bind_addr).await {
245 Ok(s) => Arc::new(s),
246 Err(e) => {
247 tracing::warn!(%peer_addr, "failed to bind guest UDP socket: {e}");
248 continue;
249 }
250 };
251
252 if let Err(e) = guest_socket.send_to(data, guest_addr).await {
254 tracing::debug!(%peer_addr, "UDP send to guest failed: {e}");
255 continue;
256 }
257
258 sessions.lock().await.insert(
260 peer_addr,
261 UdpSession {
262 guest_socket: Arc::clone(&guest_socket),
263 last_active: Instant::now(),
264 },
265 );
266
267 let host_socket_clone = Arc::clone(&host_socket);
269 let sessions_clone = Arc::clone(&sessions);
270
271 tokio::spawn(async move {
272 let mut resp_buf = [0u8; 65535];
273 loop {
274 let recv_result = tokio::time::timeout(
275 UDP_SESSION_TIMEOUT,
276 guest_socket.recv_from(&mut resp_buf),
277 )
278 .await;
279
280 match recv_result {
281 Ok(Ok((n, _from))) => {
282 if let Err(e) = host_socket_clone.send_to(&resp_buf[..n], peer_addr).await {
283 tracing::debug!(%peer_addr, "UDP send to peer failed: {e}");
284 break;
285 }
286 if let Some(session) = sessions_clone.lock().await.get_mut(&peer_addr) {
287 session.last_active = Instant::now();
288 }
289 }
290 Ok(Err(e)) => {
291 tracing::debug!(%peer_addr, "UDP recv from guest failed: {e}");
292 break;
293 }
294 Err(_timeout) => {
295 tracing::debug!(%peer_addr, "UDP session timed out");
296 sessions_clone.lock().await.remove(&peer_addr);
297 break;
298 }
299 }
300 }
301
302 sessions_clone.lock().await.remove(&peer_addr);
303 });
304 }
305}
306
307fn resolve_guest_ip(
308 host_bind: IpAddr,
309 guest_addresses: &GuestAddresses,
310) -> std::io::Result<IpAddr> {
311 match host_bind {
312 IpAddr::V4(_) => guest_addresses
313 .ipv4
314 .map(IpAddr::V4)
315 .or_else(|| guest_addresses.ipv6.map(IpAddr::V6)),
316 IpAddr::V6(_) => guest_addresses
317 .ipv6
318 .map(IpAddr::V6)
319 .or_else(|| guest_addresses.ipv4.map(IpAddr::V4)),
320 }
321 .ok_or_else(|| {
322 std::io::Error::other(format!(
323 "no guest address available for published port bind family {host_bind}"
324 ))
325 })
326}
327
#[cfg(test)]
mod tests {
    use std::net::TcpListener as StdTcpListener;

    use super::*;

    /// When a later port fails to bind, `start` must fail and release the
    /// sockets it already bound for earlier ports.
    #[tokio::test]
    async fn test_start_does_not_leak_earlier_listener_on_later_bind_failure() {
        // Builds a TCP port spec bound to the IPv4 loopback address.
        let loopback_tcp = |host_port: u16, guest_port: u16| PublishedPort {
            host_port,
            guest_port,
            protocol: PortProtocol::Tcp,
            host_bind: IpAddr::V4(Ipv4Addr::LOCALHOST),
        };

        // Kept open for the whole test so binding this port fails.
        let reserved = StdTcpListener::bind(("127.0.0.1", 0)).unwrap();
        let reserved_port = reserved.local_addr().unwrap().port();

        // Ask the OS for a free port, then release it again right away.
        let first_port = {
            let probe = StdTcpListener::bind(("127.0.0.1", 0)).unwrap();
            probe.local_addr().unwrap().port()
        };

        let ports = [
            loopback_tcp(first_port, 8080),
            loopback_tcp(reserved_port, 8081),
        ];

        let outcome =
            PortPublisher::start(&ports, Some(Ipv4Addr::new(100, 96, 0, 2)), None).await;
        assert!(outcome.is_err());

        // The first port must be bindable again after the failed start.
        StdTcpListener::bind(("127.0.0.1", first_port)).unwrap();
    }
}