1use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
9use std::sync::Arc;
10
11use smoltcp::iface::{Config, Interface, SocketSet};
12use smoltcp::time::Instant;
13use std::sync::atomic::Ordering;
14
15use smoltcp::wire::{
16 EthernetAddress, EthernetFrame, EthernetProtocol, HardwareAddress, IpAddress, IpCidr,
17 IpProtocol, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket,
18};
19
20use crate::config::{DnsConfig, PublishedPort};
21use crate::conn::ConnectionTracker;
22use crate::device::SmoltcpDevice;
23use crate::dns::interceptor::DnsInterceptor;
24use crate::policy::{NetworkPolicy, Protocol};
25use crate::proxy;
26use crate::publisher::PortPublisher;
27use crate::shared::SharedState;
28use crate::tls::{proxy as tls_proxy, state::TlsState};
29use crate::udp_relay::UdpRelay;
30
/// How the poll loop should handle a frame staged from the guest, as decided by
/// [`classify_frame`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameAction {
    /// TCP segment with SYN set and ACK clear (a connection attempt) from `src` to `dst`.
    TcpSyn { src: SocketAddr, dst: SocketAddr },

    /// UDP datagram whose destination port is not 53, from `src` to `dst`.
    UdpRelay { src: SocketAddr, dst: SocketAddr },

    /// UDP datagram addressed to destination port 53.
    Dns,

    /// Anything else: non-IP frames, other IP protocols, non-SYN TCP segments, or frames
    /// that fail to parse.
    Passthrough,
}
57
/// Static addressing and link parameters for the smoltcp poll loop.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PollLoopConfig {
    /// MAC address of the gateway side; used as the smoltcp interface's hardware address.
    pub gateway_mac: [u8; 6],
    /// MAC address of the guest's interface; passed to the UDP relay.
    pub guest_mac: [u8; 6],
    /// Gateway IPv4 address; assigned to the interface as a /30 and used as the default
    /// IPv4 route.
    pub gateway_ipv4: Ipv4Addr,
    /// Guest IPv4 address; passed to the port publisher.
    pub guest_ipv4: Ipv4Addr,
    /// Gateway IPv6 address; assigned to the interface as a /64 and used as the default
    /// IPv6 route.
    pub gateway_ipv6: Ipv6Addr,
    /// MTU handed to the smoltcp device.
    pub mtu: usize,
}
74
75pub fn classify_frame(frame: &[u8]) -> FrameAction {
85 let Ok(eth) = EthernetFrame::new_checked(frame) else {
86 return FrameAction::Passthrough;
87 };
88
89 match eth.ethertype() {
90 EthernetProtocol::Ipv4 => classify_ipv4(eth.payload()),
91 EthernetProtocol::Ipv6 => classify_ipv6(eth.payload()),
92 _ => FrameAction::Passthrough, }
94}
95
96pub fn create_interface(device: &mut SmoltcpDevice, config: &PollLoopConfig) -> Interface {
103 let hw_addr = HardwareAddress::Ethernet(EthernetAddress(config.gateway_mac));
104 let iface_config = Config::new(hw_addr);
105 let mut iface = Interface::new(iface_config, device, smoltcp_now());
106
107 iface.update_ip_addrs(|addrs| {
109 addrs
110 .push(IpCidr::new(
111 IpAddress::Ipv4(config.gateway_ipv4),
112 30,
114 ))
115 .expect("failed to add gateway IPv4 address");
116 addrs
117 .push(IpCidr::new(IpAddress::Ipv6(config.gateway_ipv6), 64))
118 .expect("failed to add gateway IPv6 address");
119 });
120
121 iface
123 .routes_mut()
124 .add_default_ipv4_route(config.gateway_ipv4)
125 .expect("failed to add default IPv4 route");
126 iface
127 .routes_mut()
128 .add_default_ipv6_route(config.gateway_ipv6)
129 .expect("failed to add default IPv6 route");
130
131 iface.set_any_ip(true);
133
134 iface
135}
136
137#[allow(clippy::too_many_arguments)]
151pub fn smoltcp_poll_loop(
152 shared: Arc<SharedState>,
153 config: PollLoopConfig,
154 network_policy: NetworkPolicy,
155 dns_config: DnsConfig,
156 tls_state: Option<Arc<TlsState>>,
157 published_ports: Vec<PublishedPort>,
158 max_connections: Option<usize>,
159 tokio_handle: tokio::runtime::Handle,
160) {
161 let mut device = SmoltcpDevice::new(shared.clone(), config.mtu);
162 let mut iface = create_interface(&mut device, &config);
163 let mut sockets = SocketSet::new(vec![]);
164 let mut conn_tracker = ConnectionTracker::new(max_connections);
165
166 let mut dns_interceptor =
167 DnsInterceptor::new(&mut sockets, dns_config, shared.clone(), &tokio_handle);
168 let mut port_publisher = PortPublisher::new(&published_ports, config.guest_ipv4, &tokio_handle);
169 let mut udp_relay = UdpRelay::new(
170 shared.clone(),
171 config.gateway_mac,
172 config.guest_mac,
173 tokio_handle.clone(),
174 );
175
176 let mut last_cleanup = std::time::Instant::now();
178
179 let mut poll_fds = [
181 libc::pollfd {
182 fd: shared.tx_wake.as_raw_fd(),
183 events: libc::POLLIN,
184 revents: 0,
185 },
186 libc::pollfd {
187 fd: shared.proxy_wake.as_raw_fd(),
188 events: libc::POLLIN,
189 revents: 0,
190 },
191 ];
192
193 loop {
194 let now = smoltcp_now();
195
196 while let Some(frame) = device.stage_next_frame() {
198 match classify_frame(frame) {
199 FrameAction::TcpSyn { src, dst } => {
200 if network_policy
202 .evaluate_egress(dst, Protocol::Tcp)
203 .is_allow()
204 && !conn_tracker.has_socket_for(&src, &dst)
205 {
206 conn_tracker.create_tcp_socket(src, dst, &mut sockets);
207 }
208 iface.poll_ingress_single(now, &mut device, &mut sockets);
211 }
212
213 FrameAction::UdpRelay { src, dst } => {
214 if let Some(ref tls) = tls_state
217 && tls.config.intercepted_ports.contains(&dst.port())
218 && tls.config.block_quic_on_intercept
219 {
220 device.drop_staged_frame();
221 continue;
222 }
223
224 if network_policy.evaluate_egress(dst, Protocol::Udp).is_deny() {
226 device.drop_staged_frame();
227 continue;
228 }
229
230 udp_relay.relay_outbound(frame, src, dst);
231 device.drop_staged_frame();
232 }
233
234 FrameAction::Dns | FrameAction::Passthrough => {
235 iface.poll_ingress_single(now, &mut device, &mut sockets);
237 }
238 }
239 }
240
241 loop {
245 let result = iface.poll_egress(now, &mut device, &mut sockets);
246 if matches!(result, smoltcp::iface::PollResult::None) {
247 break;
248 }
249 }
250 iface.poll_maintenance(now);
251
252 if device.frames_emitted.swap(false, Ordering::Relaxed) {
255 shared.rx_wake.wake();
256 }
257
258 conn_tracker.relay_data(&mut sockets);
263 dns_interceptor.process(&mut sockets);
264
265 port_publisher.accept_inbound(&mut iface, &mut sockets, &shared, &tokio_handle);
267 port_publisher.relay_data(&mut sockets);
268
269 let new_conns = conn_tracker.take_new_connections(&mut sockets);
271 for conn in new_conns {
272 if let Some(ref tls_state) = tls_state
273 && tls_state
274 .config
275 .intercepted_ports
276 .contains(&conn.dst.port())
277 {
278 tls_proxy::spawn_tls_proxy(
280 &tokio_handle,
281 conn.dst,
282 conn.from_smoltcp,
283 conn.to_smoltcp,
284 shared.clone(),
285 tls_state.clone(),
286 );
287 continue;
288 }
289 proxy::spawn_tcp_proxy(
291 &tokio_handle,
292 conn.dst,
293 conn.from_smoltcp,
294 conn.to_smoltcp,
295 shared.clone(),
296 );
297 }
298
299 if last_cleanup.elapsed() >= std::time::Duration::from_secs(1) {
302 conn_tracker.cleanup_closed(&mut sockets);
303 port_publisher.cleanup_closed(&mut sockets);
304 udp_relay.cleanup_expired();
305 last_cleanup = std::time::Instant::now();
306 }
307
308 loop {
311 let result = iface.poll_egress(now, &mut device, &mut sockets);
312 if matches!(result, smoltcp::iface::PollResult::None) {
313 break;
314 }
315 }
316
317 if device.frames_emitted.swap(false, Ordering::Relaxed) {
319 shared.rx_wake.wake();
320 }
321
322 let timeout_ms = iface
323 .poll_delay(now, &sockets)
324 .map(|d| d.total_millis().min(i32::MAX as u64) as i32)
325 .unwrap_or(100); unsafe {
329 libc::poll(
330 poll_fds.as_mut_ptr(),
331 poll_fds.len() as libc::nfds_t,
332 timeout_ms,
333 );
334 }
335
336 if poll_fds[0].revents & libc::POLLIN != 0 {
338 shared.tx_wake.drain();
339 }
340 if poll_fds[1].revents & libc::POLLIN != 0 {
341 shared.proxy_wake.drain();
342 }
343 }
344}
345
346fn smoltcp_now() -> Instant {
356 static EPOCH: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
357 let epoch = EPOCH.get_or_init(std::time::Instant::now);
358 let elapsed = epoch.elapsed();
359 Instant::from_millis(elapsed.as_millis() as i64)
360}
361
362fn classify_ipv4(payload: &[u8]) -> FrameAction {
364 let Ok(ipv4) = Ipv4Packet::new_checked(payload) else {
365 return FrameAction::Passthrough;
366 };
367 classify_transport(
368 ipv4.next_header(),
369 ipv4.src_addr().into(),
370 ipv4.dst_addr().into(),
371 ipv4.payload(),
372 )
373}
374
375fn classify_ipv6(payload: &[u8]) -> FrameAction {
377 let Ok(ipv6) = Ipv6Packet::new_checked(payload) else {
378 return FrameAction::Passthrough;
379 };
380 classify_transport(
381 ipv6.next_header(),
382 ipv6.src_addr().into(),
383 ipv6.dst_addr().into(),
384 ipv6.payload(),
385 )
386}
387
388fn classify_transport(
390 protocol: IpProtocol,
391 src_ip: std::net::IpAddr,
392 dst_ip: std::net::IpAddr,
393 transport_payload: &[u8],
394) -> FrameAction {
395 match protocol {
396 IpProtocol::Tcp => {
397 let Ok(tcp) = TcpPacket::new_checked(transport_payload) else {
398 return FrameAction::Passthrough;
399 };
400 if tcp.syn() && !tcp.ack() {
401 FrameAction::TcpSyn {
402 src: SocketAddr::new(src_ip, tcp.src_port()),
403 dst: SocketAddr::new(dst_ip, tcp.dst_port()),
404 }
405 } else {
406 FrameAction::Passthrough
407 }
408 }
409 IpProtocol::Udp => {
410 let Ok(udp) = UdpPacket::new_checked(transport_payload) else {
411 return FrameAction::Passthrough;
412 };
413 if udp.dst_port() == 53 {
414 FrameAction::Dns
415 } else {
416 FrameAction::UdpRelay {
417 src: SocketAddr::new(src_ip, udp.src_port()),
418 dst: SocketAddr::new(dst_ip, udp.dst_port()),
419 }
420 }
421 }
422 _ => FrameAction::Passthrough, }
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Index of the TCP flags byte in frames from `build_tcp_syn_frame`
    /// (14-byte Ethernet header + 20-byte IPv4 header + offset 13 in the TCP header).
    const TCP_FLAGS_OFFSET: usize = 14 + 20 + 13;

    /// 14-byte Ethernet header with zeroed MAC addresses and the given EtherType.
    fn ethernet_header(ethertype: u16) -> Vec<u8> {
        let mut header = vec![0u8; 12];
        header.extend_from_slice(&ethertype.to_be_bytes());
        header
    }

    /// Minimal 20-byte IPv4 header: IHL 5, TTL 64, checksum left at zero.
    fn ipv4_header(
        total_len: u16,
        flags: u8,
        protocol: u8,
        src: [u8; 4],
        dst: [u8; 4],
    ) -> [u8; 20] {
        let mut header = [0u8; 20];
        header[0] = 0x45;
        header[2..4].copy_from_slice(&total_len.to_be_bytes());
        header[6] = flags;
        header[8] = 64;
        header[9] = protocol;
        header[12..16].copy_from_slice(&src);
        header[16..20].copy_from_slice(&dst);
        header
    }

    /// Ethernet + IPv4 (DF set) + 20-byte TCP header with only SYN set.
    fn build_tcp_syn_frame(
        src_ip: [u8; 4],
        dst_ip: [u8; 4],
        src_port: u16,
        dst_port: u16,
    ) -> Vec<u8> {
        let mut tcp = [0u8; 20];
        tcp[0..2].copy_from_slice(&src_port.to_be_bytes());
        tcp[2..4].copy_from_slice(&dst_port.to_be_bytes());
        tcp[12] = 0x50; // data offset: 5 words
        tcp[13] = 0x02; // flags: SYN

        let mut frame = ethernet_header(0x0800);
        frame.extend_from_slice(&ipv4_header(40, 0x40, 6, src_ip, dst_ip));
        frame.extend_from_slice(&tcp);
        frame
    }

    /// Ethernet + IPv4 + empty 8-byte UDP datagram (checksum left at zero).
    fn build_udp_frame(src_ip: [u8; 4], dst_ip: [u8; 4], src_port: u16, dst_port: u16) -> Vec<u8> {
        let mut udp = [0u8; 8];
        udp[0..2].copy_from_slice(&src_port.to_be_bytes());
        udp[2..4].copy_from_slice(&dst_port.to_be_bytes());
        udp[4..6].copy_from_slice(&8u16.to_be_bytes());

        let mut frame = ethernet_header(0x0800);
        frame.extend_from_slice(&ipv4_header(28, 0x00, 17, src_ip, dst_ip));
        frame.extend_from_slice(&udp);
        frame
    }

    #[test]
    fn classify_tcp_syn() {
        let frame = build_tcp_syn_frame([10, 0, 0, 2], [93, 184, 216, 34], 54321, 443);
        let FrameAction::TcpSyn { src, dst } = classify_frame(&frame) else {
            panic!("expected TcpSyn");
        };
        assert_eq!(
            src,
            SocketAddr::new(Ipv4Addr::new(10, 0, 0, 2).into(), 54321)
        );
        assert_eq!(
            dst,
            SocketAddr::new(Ipv4Addr::new(93, 184, 216, 34).into(), 443)
        );
    }

    #[test]
    fn classify_tcp_ack_is_passthrough() {
        let mut frame = build_tcp_syn_frame([10, 0, 0, 2], [93, 184, 216, 34], 54321, 443);
        frame[TCP_FLAGS_OFFSET] = 0x10; // ACK only
        assert!(matches!(classify_frame(&frame), FrameAction::Passthrough));
    }

    #[test]
    fn classify_udp_dns() {
        let frame = build_udp_frame([10, 0, 0, 2], [10, 0, 0, 1], 12345, 53);
        assert!(matches!(classify_frame(&frame), FrameAction::Dns));
    }

    #[test]
    fn classify_udp_non_dns() {
        let frame = build_udp_frame([10, 0, 0, 2], [8, 8, 8, 8], 12345, 443);
        let FrameAction::UdpRelay { src, dst } = classify_frame(&frame) else {
            panic!("expected UdpRelay");
        };
        assert_eq!(src.port(), 12345);
        assert_eq!(dst.port(), 443);
    }

    #[test]
    fn classify_arp_is_passthrough() {
        let mut frame = ethernet_header(0x0806);
        frame.resize(42, 0);
        assert!(matches!(classify_frame(&frame), FrameAction::Passthrough));
    }

    #[test]
    fn classify_garbage_is_passthrough() {
        for garbage in [&[][..], &[0u8; 5][..]] {
            assert!(matches!(classify_frame(garbage), FrameAction::Passthrough));
        }
    }
}