1use std::collections::HashMap;
8use std::net::{IpAddr, Ipv4Addr, SocketAddr};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use bytes::Bytes;
13use smoltcp::wire::{
14 EthernetAddress, EthernetFrame, EthernetProtocol, EthernetRepr, IpProtocol, Ipv4Packet,
15 Ipv6Packet, UdpPacket,
16};
17use tokio::net::UdpSocket;
18use tokio::sync::mpsc;
19
20use crate::shared::SharedState;
21
/// Idle lifetime of a UDP relay session.
///
/// Used in two places: the relay drops or replaces a session whose last
/// outbound datagram is older than this, and each relay task exits after
/// this long without any event on its channel or socket.
const SESSION_TIMEOUT: Duration = Duration::from_secs(60);

/// Capacity of the per-session channel that carries outbound payloads from
/// the relay to its background task. Sends use `try_send`, so datagrams are
/// dropped rather than queued once the channel is full.
const OUTBOUND_CHANNEL_CAPACITY: usize = 64;

/// Size in bytes of the buffer each relay task uses to receive datagrams
/// from the host socket.
// NOTE(review): a datagram longer than this buffer would presumably be
// truncated by `recv` — confirm this bound is intentional.
const RECV_BUF_SIZE: usize = 4096;

/// Length in bytes of the Ethernet header placed in front of every
/// synthesized response frame.
const ETH_HDR_LEN: usize = 14;

/// Length in bytes of the IPv4 header emitted for responses (no options).
const IPV4_HDR_LEN: usize = 20;

/// Length in bytes of the UDP header.
const UDP_HDR_LEN: usize = 8;
45
/// Relays UDP datagrams between guest Ethernet frames and host UDP sockets.
///
/// Each distinct `(source, guest destination)` address pair gets its own
/// session, backed by a task spawned on the stored tokio runtime handle.
pub struct UdpRelay {
    /// Shared state handed to every relay task; tasks push synthesized
    /// response frames into it.
    shared: Arc<SharedState>,
    /// Active sessions keyed by `(guest source, guest-visible destination)`.
    sessions: HashMap<(SocketAddr, SocketAddr), UdpSession>,
    /// MAC address used as the Ethernet source of response frames.
    gateway_mac: EthernetAddress,
    /// MAC address used as the Ethernet destination of response frames.
    guest_mac: EthernetAddress,
    /// Runtime handle on which relay tasks are spawned.
    tokio_handle: tokio::runtime::Handle,
}
65
/// Relay-side handle for one UDP flow.
struct UdpSession {
    /// Sending half of the channel that feeds outbound payloads to the
    /// session's relay task.
    outbound_tx: mpsc::Sender<Bytes>,
    /// Time of the most recent outbound datagram; compared against
    /// `SESSION_TIMEOUT` to expire the session.
    last_active: Instant,
}
73
74impl UdpRelay {
79 pub fn new(
89 shared: Arc<SharedState>,
90 gateway_mac: [u8; 6],
91 guest_mac: [u8; 6],
92 tokio_handle: tokio::runtime::Handle,
93 ) -> Self {
94 Self {
95 shared,
96 sessions: HashMap::new(),
97 gateway_mac: EthernetAddress(gateway_mac),
98 guest_mac: EthernetAddress(guest_mac),
99 tokio_handle,
100 }
101 }
102
103 pub fn relay_outbound(
115 &mut self,
116 frame: &[u8],
117 src: SocketAddr,
118 guest_dst: SocketAddr,
119 host_dst: SocketAddr,
120 ) {
121 let Some(payload) = extract_udp_payload(frame) else {
123 return;
124 };
125
126 let key = (src, guest_dst);
127
128 if self
130 .sessions
131 .get(&key)
132 .is_none_or(|s| s.last_active.elapsed() > SESSION_TIMEOUT)
133 {
134 self.sessions.remove(&key);
135 if let Some(session) = self.create_session(src, guest_dst, host_dst) {
136 self.sessions.insert(key, session);
137 } else {
138 return;
139 }
140 }
141
142 if let Some(session) = self.sessions.get_mut(&key) {
143 session.last_active = Instant::now();
144 let _ = session
145 .outbound_tx
146 .try_send(Bytes::copy_from_slice(payload));
147 }
148 }
149
150 pub fn cleanup_expired(&mut self) {
152 self.sessions
153 .retain(|_, session| session.last_active.elapsed() <= SESSION_TIMEOUT);
154 }
155}
156
157impl UdpRelay {
158 fn create_session(
160 &self,
161 guest_src: SocketAddr,
162 guest_dst: SocketAddr,
163 host_dst: SocketAddr,
164 ) -> Option<UdpSession> {
165 let (outbound_tx, outbound_rx) = mpsc::channel(OUTBOUND_CHANNEL_CAPACITY);
166
167 let shared = self.shared.clone();
168 let gateway_mac = self.gateway_mac;
169 let guest_mac = self.guest_mac;
170
171 self.tokio_handle.spawn(async move {
172 if let Err(e) = udp_relay_task(
173 outbound_rx,
174 guest_src,
175 guest_dst,
176 host_dst,
177 shared,
178 gateway_mac,
179 guest_mac,
180 )
181 .await
182 {
183 tracing::debug!(
184 guest_src = %guest_src,
185 guest_dst = %guest_dst,
186 error = %e,
187 "UDP relay task ended",
188 );
189 }
190 });
191
192 Some(UdpSession {
193 outbound_tx,
194 last_active: Instant::now(),
195 })
196 }
197}
198
199#[allow(clippy::too_many_arguments)]
232async fn udp_relay_task(
233 mut outbound_rx: mpsc::Receiver<Bytes>,
234 guest_src: SocketAddr,
235 guest_dst: SocketAddr,
236 host_dst: SocketAddr,
237 shared: Arc<SharedState>,
238 gateway_mac: EthernetAddress,
239 guest_mac: EthernetAddress,
240) -> std::io::Result<()> {
241 let bind_addr: SocketAddr = match host_dst {
243 SocketAddr::V4(_) => (Ipv4Addr::UNSPECIFIED, 0u16).into(),
244 SocketAddr::V6(_) => (std::net::Ipv6Addr::UNSPECIFIED, 0u16).into(),
245 };
246 let socket = UdpSocket::bind(bind_addr).await?;
247 socket.connect(host_dst).await?;
250
251 let mut recv_buf = vec![0u8; RECV_BUF_SIZE];
252 let timeout = SESSION_TIMEOUT;
253
254 loop {
255 tokio::select! {
256 data = outbound_rx.recv() => {
258 match data {
259 Some(payload) => {
260 let _ = socket.send(&payload).await;
261 }
262 None => break,
264 }
265 }
266
267 result = socket.recv(&mut recv_buf) => {
269 match result {
270 Ok(n) => {
271 if let Some(frame) = construct_udp_response(
272 guest_dst,
273 guest_src,
274 &recv_buf[..n],
275 gateway_mac,
276 guest_mac,
277 ) && !shared.push_rx_frame_and_wake(frame) {
278 tracing::debug!("UDP relay response dropped because rx_ring is full");
279 }
280 }
281 Err(e) => {
282 tracing::debug!(error = %e, "UDP relay recv failed");
283 break;
284 }
285 }
286 }
287
288 () = tokio::time::sleep(timeout) => {
290 break;
291 }
292 }
293 }
294
295 Ok(())
296}
297
298pub(crate) fn construct_udp_response(
302 src: SocketAddr,
303 dst: SocketAddr,
304 payload: &[u8],
305 gateway_mac: EthernetAddress,
306 guest_mac: EthernetAddress,
307) -> Option<Vec<u8>> {
308 match (src.ip(), dst.ip()) {
309 (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => Some(construct_udp_response_v4(
310 src_ip,
311 src.port(),
312 dst_ip,
313 dst.port(),
314 payload,
315 gateway_mac,
316 guest_mac,
317 )),
318 (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => Some(construct_udp_response_v6(
319 src_ip,
320 src.port(),
321 dst_ip,
322 dst.port(),
323 payload,
324 gateway_mac,
325 guest_mac,
326 )),
327 _ => None, }
329}
330
331fn construct_udp_response_v4(
333 src_ip: Ipv4Addr,
334 src_port: u16,
335 dst_ip: Ipv4Addr,
336 dst_port: u16,
337 payload: &[u8],
338 gateway_mac: EthernetAddress,
339 guest_mac: EthernetAddress,
340) -> Vec<u8> {
341 let udp_len = UDP_HDR_LEN + payload.len();
342 let ip_total_len = IPV4_HDR_LEN + udp_len;
343 let frame_len = ETH_HDR_LEN + ip_total_len;
344 let mut buf = vec![0u8; frame_len];
345
346 let eth_repr = EthernetRepr {
348 src_addr: gateway_mac,
349 dst_addr: guest_mac,
350 ethertype: EthernetProtocol::Ipv4,
351 };
352 let mut eth_frame = EthernetFrame::new_unchecked(&mut buf);
353 eth_repr.emit(&mut eth_frame);
354
355 let ip_buf = &mut buf[ETH_HDR_LEN..];
357 let mut ip_pkt = Ipv4Packet::new_unchecked(ip_buf);
358 ip_pkt.set_version(4);
359 ip_pkt.set_header_len(20);
360 ip_pkt.set_total_len(ip_total_len as u16);
361 ip_pkt.clear_flags();
362 ip_pkt.set_dont_frag(true);
363 ip_pkt.set_hop_limit(64);
364 ip_pkt.set_next_header(IpProtocol::Udp);
365 ip_pkt.set_src_addr(src_ip);
366 ip_pkt.set_dst_addr(dst_ip);
367 ip_pkt.fill_checksum();
368
369 let udp_buf = &mut buf[ETH_HDR_LEN + IPV4_HDR_LEN..];
371 let mut udp_pkt = UdpPacket::new_unchecked(udp_buf);
372 udp_pkt.set_src_port(src_port);
373 udp_pkt.set_dst_port(dst_port);
374 udp_pkt.set_len(udp_len as u16);
375 udp_pkt.set_checksum(0); udp_pkt.payload_mut()[..payload.len()].copy_from_slice(payload);
377
378 buf
379}
380
381fn construct_udp_response_v6(
383 src_ip: std::net::Ipv6Addr,
384 src_port: u16,
385 dst_ip: std::net::Ipv6Addr,
386 dst_port: u16,
387 payload: &[u8],
388 gateway_mac: EthernetAddress,
389 guest_mac: EthernetAddress,
390) -> Vec<u8> {
391 let udp_len = UDP_HDR_LEN + payload.len();
392 let ipv6_hdr_len = 40;
393 let frame_len = ETH_HDR_LEN + ipv6_hdr_len + udp_len;
394 let mut buf = vec![0u8; frame_len];
395
396 let eth_repr = EthernetRepr {
398 src_addr: gateway_mac,
399 dst_addr: guest_mac,
400 ethertype: EthernetProtocol::Ipv6,
401 };
402 let mut eth_frame = EthernetFrame::new_unchecked(&mut buf);
403 eth_repr.emit(&mut eth_frame);
404
405 let ip_buf = &mut buf[ETH_HDR_LEN..];
407 let mut ip_pkt = Ipv6Packet::new_unchecked(ip_buf);
408 ip_pkt.set_version(6);
409 ip_pkt.set_payload_len(udp_len as u16);
410 ip_pkt.set_next_header(IpProtocol::Udp);
411 ip_pkt.set_hop_limit(64);
412 ip_pkt.set_src_addr(src_ip);
413 ip_pkt.set_dst_addr(dst_ip);
414
415 let udp_buf = &mut buf[ETH_HDR_LEN + ipv6_hdr_len..];
417 let mut udp_pkt = UdpPacket::new_unchecked(udp_buf);
418 udp_pkt.set_src_port(src_port);
419 udp_pkt.set_dst_port(dst_port);
420 udp_pkt.set_len(udp_len as u16);
421 udp_pkt.payload_mut()[..payload.len()].copy_from_slice(payload);
424 udp_pkt.fill_checksum(
427 &smoltcp::wire::IpAddress::from(src_ip),
428 &smoltcp::wire::IpAddress::from(dst_ip),
429 );
430
431 buf
432}
433
434pub(crate) fn extract_udp_payload(frame: &[u8]) -> Option<&[u8]> {
436 let eth = EthernetFrame::new_checked(frame).ok()?;
437 match eth.ethertype() {
438 EthernetProtocol::Ipv4 => {
439 let ipv4 = Ipv4Packet::new_checked(eth.payload()).ok()?;
440 let udp = UdpPacket::new_checked(ipv4.payload()).ok()?;
441 Some(udp.payload())
442 }
443 EthernetProtocol::Ipv6 => {
444 let ipv6 = Ipv6Packet::new_checked(eth.payload()).ok()?;
445 let udp = UdpPacket::new_checked(ipv6.payload()).ok()?;
446 Some(udp.payload())
447 }
448 _ => None,
449 }
450}
451
#[cfg(test)]
mod tests {
    //! Round-trip checks for frame construction and payload extraction.

    use super::*;

    /// Gateway-side MAC used by the structure tests.
    const GATEWAY_MAC: EthernetAddress = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
    /// Guest-side MAC used by the structure tests.
    const GUEST_MAC: EthernetAddress = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]);
    /// All-zero MAC used where the addresses are irrelevant.
    const ZERO_MAC: EthernetAddress = EthernetAddress([0; 6]);
    /// Length in bytes of the fixed IPv6 header.
    const IPV6_HDR_LEN: usize = 40;

    /// Parses an IPv6 literal, panicking on malformed input.
    fn v6(text: &str) -> std::net::Ipv6Addr {
        text.parse::<std::net::Ipv6Addr>().unwrap()
    }

    #[test]
    fn construct_v4_response_has_correct_structure() {
        let remote = Ipv4Addr::new(8, 8, 8, 8);
        let guest = Ipv4Addr::new(100, 96, 0, 2);
        let body = b"hello";

        let frame =
            construct_udp_response_v4(remote, 53, guest, 12345, body, GATEWAY_MAC, GUEST_MAC);

        assert_eq!(frame.len(), ETH_HDR_LEN + IPV4_HDR_LEN + UDP_HDR_LEN + 5);

        // Ethernet layer.
        let eth = EthernetFrame::new_checked(&frame).unwrap();
        assert_eq!(eth.ethertype(), EthernetProtocol::Ipv4);
        assert_eq!(
            eth.dst_addr(),
            EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02])
        );

        // IPv4 layer.
        let ipv4 = Ipv4Packet::new_checked(eth.payload()).unwrap();
        assert_eq!(ipv4.src_addr(), Ipv4Addr::new(8, 8, 8, 8));
        assert_eq!(ipv4.dst_addr(), Ipv4Addr::new(100, 96, 0, 2));
        assert_eq!(ipv4.next_header(), IpProtocol::Udp);

        // UDP layer.
        let udp = UdpPacket::new_checked(ipv4.payload()).unwrap();
        assert_eq!(udp.src_port(), 53);
        assert_eq!(udp.dst_port(), 12345);
        assert_eq!(udp.payload(), b"hello");
    }

    #[test]
    fn construct_v6_response_has_correct_structure() {
        let body = b"hello ipv6";
        let remote = v6("2001:db8::1");
        let guest = v6("fd42:6d73:62::2");

        let frame =
            construct_udp_response_v6(remote, 53, guest, 12345, body, GATEWAY_MAC, GUEST_MAC);

        assert_eq!(
            frame.len(),
            ETH_HDR_LEN + IPV6_HDR_LEN + UDP_HDR_LEN + body.len()
        );

        // Ethernet layer.
        let eth = EthernetFrame::new_checked(&frame).unwrap();
        assert_eq!(eth.ethertype(), EthernetProtocol::Ipv6);

        // IPv6 layer.
        let ipv6 = Ipv6Packet::new_checked(eth.payload()).unwrap();
        assert_eq!(ipv6.next_header(), IpProtocol::Udp);

        // UDP layer, including the checksum that IPv6 relies on.
        let udp = UdpPacket::new_checked(ipv6.payload()).unwrap();
        assert_eq!(udp.src_port(), 53);
        assert_eq!(udp.dst_port(), 12345);
        assert_eq!(udp.payload(), b"hello ipv6");
        assert_ne!(udp.checksum(), 0, "IPv6 UDP checksum must not be zero");
        assert!(
            udp.verify_checksum(
                &smoltcp::wire::IpAddress::from(remote),
                &smoltcp::wire::IpAddress::from(guest),
            ),
            "IPv6 UDP checksum must be valid"
        );
    }

    #[test]
    fn extract_payload_from_v6_udp_frame() {
        let frame = construct_udp_response_v6(
            v6("2001:db8::1"),
            80,
            v6("fd42:6d73:62::2"),
            54321,
            b"v6 data",
            ZERO_MAC,
            ZERO_MAC,
        );

        let extracted = extract_udp_payload(&frame).unwrap();
        assert_eq!(extracted, b"v6 data");
    }

    #[test]
    fn extract_payload_from_v4_udp_frame() {
        let frame = construct_udp_response_v4(
            Ipv4Addr::new(1, 2, 3, 4),
            80,
            Ipv4Addr::new(10, 0, 0, 2),
            54321,
            b"test data",
            ZERO_MAC,
            ZERO_MAC,
        );

        let extracted = extract_udp_payload(&frame).unwrap();
        assert_eq!(extracted, b"test data");
    }
}