1use std::collections::HashMap;
8use std::net::{IpAddr, Ipv4Addr, SocketAddr};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use bytes::Bytes;
13use smoltcp::wire::{
14 EthernetAddress, EthernetFrame, EthernetProtocol, EthernetRepr, IpProtocol, Ipv4Packet,
15 Ipv6Packet, UdpPacket,
16};
17use tokio::net::UdpSocket;
18use tokio::sync::mpsc;
19
20use crate::shared::SharedState;
21
/// Idle period after which a UDP session is considered expired.
///
/// Used both by `UdpRelay` to decide when a session must be replaced and by the
/// per-session relay task as its inactivity timer.
const SESSION_TIMEOUT: Duration = Duration::from_millis(60_000);

/// Number of guest datagrams that may wait in a session's queue before further
/// datagrams for that session are dropped.
const OUTBOUND_CHANNEL_CAPACITY: usize = 64;

/// Size of the buffer each relay task receives host datagrams into.
const RECV_BUF_SIZE: usize = 4 * 1024;

/// Ethernet II header: destination MAC (6) + source MAC (6) + EtherType (2).
const ETH_HDR_LEN: usize = 6 + 6 + 2;

/// IPv4 header without options: an IHL of five 32-bit words.
const IPV4_HDR_LEN: usize = 5 * 4;

/// UDP header: source port, destination port, length and checksum, 2 bytes each.
const UDP_HDR_LEN: usize = 4 * 2;
45
46pub struct UdpRelay {
59 shared: Arc<SharedState>,
60 sessions: HashMap<(SocketAddr, SocketAddr), UdpSession>,
61 gateway_mac: EthernetAddress,
62 guest_mac: EthernetAddress,
63 tokio_handle: tokio::runtime::Handle,
64}
65
66struct UdpSession {
68 outbound_tx: mpsc::Sender<Bytes>,
70 last_active: Instant,
72}
73
74impl UdpRelay {
79 pub fn new(
81 shared: Arc<SharedState>,
82 gateway_mac: [u8; 6],
83 guest_mac: [u8; 6],
84 tokio_handle: tokio::runtime::Handle,
85 ) -> Self {
86 Self {
87 shared,
88 sessions: HashMap::new(),
89 gateway_mac: EthernetAddress(gateway_mac),
90 guest_mac: EthernetAddress(guest_mac),
91 tokio_handle,
92 }
93 }
94
95 pub fn relay_outbound(&mut self, frame: &[u8], src: SocketAddr, dst: SocketAddr) {
100 let Some(payload) = extract_udp_payload(frame) else {
102 return;
103 };
104
105 let key = (src, dst);
106
107 if self
109 .sessions
110 .get(&key)
111 .is_none_or(|s| s.last_active.elapsed() > SESSION_TIMEOUT)
112 {
113 self.sessions.remove(&key);
114 if let Some(session) = self.create_session(src, dst) {
115 self.sessions.insert(key, session);
116 } else {
117 return;
118 }
119 }
120
121 if let Some(session) = self.sessions.get_mut(&key) {
122 session.last_active = Instant::now();
123 let _ = session
124 .outbound_tx
125 .try_send(Bytes::copy_from_slice(payload));
126 }
127 }
128
129 pub fn cleanup_expired(&mut self) {
131 self.sessions
132 .retain(|_, session| session.last_active.elapsed() <= SESSION_TIMEOUT);
133 }
134}
135
136impl UdpRelay {
137 fn create_session(&self, guest_src: SocketAddr, guest_dst: SocketAddr) -> Option<UdpSession> {
139 let (outbound_tx, outbound_rx) = mpsc::channel(OUTBOUND_CHANNEL_CAPACITY);
140
141 let shared = self.shared.clone();
142 let gateway_mac = self.gateway_mac;
143 let guest_mac = self.guest_mac;
144
145 self.tokio_handle.spawn(async move {
146 if let Err(e) = udp_relay_task(
147 outbound_rx,
148 guest_src,
149 guest_dst,
150 shared,
151 gateway_mac,
152 guest_mac,
153 )
154 .await
155 {
156 tracing::debug!(
157 guest_src = %guest_src,
158 guest_dst = %guest_dst,
159 error = %e,
160 "UDP relay task ended",
161 );
162 }
163 });
164
165 Some(UdpSession {
166 outbound_tx,
167 last_active: Instant::now(),
168 })
169 }
170}
171
172async fn udp_relay_task(
178 mut outbound_rx: mpsc::Receiver<Bytes>,
179 guest_src: SocketAddr,
180 guest_dst: SocketAddr,
181 shared: Arc<SharedState>,
182 gateway_mac: EthernetAddress,
183 guest_mac: EthernetAddress,
184) -> std::io::Result<()> {
185 let bind_addr: SocketAddr = match guest_dst {
187 SocketAddr::V4(_) => (Ipv4Addr::UNSPECIFIED, 0u16).into(),
188 SocketAddr::V6(_) => (std::net::Ipv6Addr::UNSPECIFIED, 0u16).into(),
189 };
190 let socket = UdpSocket::bind(bind_addr).await?;
191 socket.connect(guest_dst).await?;
194
195 let mut recv_buf = vec![0u8; RECV_BUF_SIZE];
196 let timeout = SESSION_TIMEOUT;
197
198 loop {
199 tokio::select! {
200 data = outbound_rx.recv() => {
202 match data {
203 Some(payload) => {
204 let _ = socket.send(&payload).await;
205 }
206 None => break,
208 }
209 }
210
211 result = socket.recv(&mut recv_buf) => {
213 match result {
214 Ok(n) => {
215 if let Some(frame) = construct_udp_response(
216 guest_dst,
217 guest_src,
218 &recv_buf[..n],
219 gateway_mac,
220 guest_mac,
221 ) {
222 let _ = shared.rx_ring.push(frame);
223 shared.rx_wake.wake();
224 }
225 }
226 Err(e) => {
227 tracing::debug!(error = %e, "UDP relay recv failed");
228 break;
229 }
230 }
231 }
232
233 () = tokio::time::sleep(timeout) => {
235 break;
236 }
237 }
238 }
239
240 Ok(())
241}
242
243fn construct_udp_response(
247 src: SocketAddr,
248 dst: SocketAddr,
249 payload: &[u8],
250 gateway_mac: EthernetAddress,
251 guest_mac: EthernetAddress,
252) -> Option<Vec<u8>> {
253 match (src.ip(), dst.ip()) {
254 (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => Some(construct_udp_response_v4(
255 src_ip,
256 src.port(),
257 dst_ip,
258 dst.port(),
259 payload,
260 gateway_mac,
261 guest_mac,
262 )),
263 (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => Some(construct_udp_response_v6(
264 src_ip,
265 src.port(),
266 dst_ip,
267 dst.port(),
268 payload,
269 gateway_mac,
270 guest_mac,
271 )),
272 _ => None, }
274}
275
276fn construct_udp_response_v4(
278 src_ip: Ipv4Addr,
279 src_port: u16,
280 dst_ip: Ipv4Addr,
281 dst_port: u16,
282 payload: &[u8],
283 gateway_mac: EthernetAddress,
284 guest_mac: EthernetAddress,
285) -> Vec<u8> {
286 let udp_len = UDP_HDR_LEN + payload.len();
287 let ip_total_len = IPV4_HDR_LEN + udp_len;
288 let frame_len = ETH_HDR_LEN + ip_total_len;
289 let mut buf = vec![0u8; frame_len];
290
291 let eth_repr = EthernetRepr {
293 src_addr: gateway_mac,
294 dst_addr: guest_mac,
295 ethertype: EthernetProtocol::Ipv4,
296 };
297 let mut eth_frame = EthernetFrame::new_unchecked(&mut buf);
298 eth_repr.emit(&mut eth_frame);
299
300 let ip_buf = &mut buf[ETH_HDR_LEN..];
302 let mut ip_pkt = Ipv4Packet::new_unchecked(ip_buf);
303 ip_pkt.set_version(4);
304 ip_pkt.set_header_len(20);
305 ip_pkt.set_total_len(ip_total_len as u16);
306 ip_pkt.clear_flags();
307 ip_pkt.set_dont_frag(true);
308 ip_pkt.set_hop_limit(64);
309 ip_pkt.set_next_header(IpProtocol::Udp);
310 ip_pkt.set_src_addr(src_ip);
311 ip_pkt.set_dst_addr(dst_ip);
312 ip_pkt.fill_checksum();
313
314 let udp_buf = &mut buf[ETH_HDR_LEN + IPV4_HDR_LEN..];
316 let mut udp_pkt = UdpPacket::new_unchecked(udp_buf);
317 udp_pkt.set_src_port(src_port);
318 udp_pkt.set_dst_port(dst_port);
319 udp_pkt.set_len(udp_len as u16);
320 udp_pkt.set_checksum(0); udp_pkt.payload_mut()[..payload.len()].copy_from_slice(payload);
322
323 buf
324}
325
326fn construct_udp_response_v6(
328 src_ip: std::net::Ipv6Addr,
329 src_port: u16,
330 dst_ip: std::net::Ipv6Addr,
331 dst_port: u16,
332 payload: &[u8],
333 gateway_mac: EthernetAddress,
334 guest_mac: EthernetAddress,
335) -> Vec<u8> {
336 let udp_len = UDP_HDR_LEN + payload.len();
337 let ipv6_hdr_len = 40;
338 let frame_len = ETH_HDR_LEN + ipv6_hdr_len + udp_len;
339 let mut buf = vec![0u8; frame_len];
340
341 let eth_repr = EthernetRepr {
343 src_addr: gateway_mac,
344 dst_addr: guest_mac,
345 ethertype: EthernetProtocol::Ipv6,
346 };
347 let mut eth_frame = EthernetFrame::new_unchecked(&mut buf);
348 eth_repr.emit(&mut eth_frame);
349
350 let ip_buf = &mut buf[ETH_HDR_LEN..];
352 let mut ip_pkt = Ipv6Packet::new_unchecked(ip_buf);
353 ip_pkt.set_version(6);
354 ip_pkt.set_payload_len(udp_len as u16);
355 ip_pkt.set_next_header(IpProtocol::Udp);
356 ip_pkt.set_hop_limit(64);
357 ip_pkt.set_src_addr(src_ip);
358 ip_pkt.set_dst_addr(dst_ip);
359
360 let udp_buf = &mut buf[ETH_HDR_LEN + ipv6_hdr_len..];
362 let mut udp_pkt = UdpPacket::new_unchecked(udp_buf);
363 udp_pkt.set_src_port(src_port);
364 udp_pkt.set_dst_port(dst_port);
365 udp_pkt.set_len(udp_len as u16);
366 udp_pkt.payload_mut()[..payload.len()].copy_from_slice(payload);
369 udp_pkt.fill_checksum(
372 &smoltcp::wire::IpAddress::from(src_ip),
373 &smoltcp::wire::IpAddress::from(dst_ip),
374 );
375
376 buf
377}
378
379fn extract_udp_payload(frame: &[u8]) -> Option<&[u8]> {
381 let eth = EthernetFrame::new_checked(frame).ok()?;
382 match eth.ethertype() {
383 EthernetProtocol::Ipv4 => {
384 let ipv4 = Ipv4Packet::new_checked(eth.payload()).ok()?;
385 let udp = UdpPacket::new_checked(ipv4.payload()).ok()?;
386 Some(udp.payload())
387 }
388 EthernetProtocol::Ipv6 => {
389 let ipv6 = Ipv6Packet::new_checked(eth.payload()).ok()?;
390 let udp = UdpPacket::new_checked(ipv6.payload()).ok()?;
391 Some(udp.payload())
392 }
393 _ => None,
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Gateway MAC (frame source) used by the structure tests.
    const GATEWAY_MAC: EthernetAddress = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
    /// Guest MAC (frame destination) used by the structure tests.
    const GUEST_MAC: EthernetAddress = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]);
    /// Placeholder MAC for round-trip tests where link addressing is irrelevant.
    const ANY_MAC: EthernetAddress = EthernetAddress([0; 6]);
    /// Fixed IPv6 header length.
    const IPV6_HDR_LEN: usize = 40;

    /// Parses an IPv6 literal used as a test fixture.
    fn ipv6(text: &str) -> std::net::Ipv6Addr {
        text.parse().unwrap()
    }

    #[test]
    fn construct_v4_response_has_correct_structure() {
        let server = Ipv4Addr::new(8, 8, 8, 8);
        let guest = Ipv4Addr::new(100, 96, 0, 2);
        let payload = b"hello";

        let frame =
            construct_udp_response_v4(server, 53, guest, 12345, payload, GATEWAY_MAC, GUEST_MAC);
        assert_eq!(
            frame.len(),
            ETH_HDR_LEN + IPV4_HDR_LEN + UDP_HDR_LEN + payload.len()
        );

        let eth = EthernetFrame::new_checked(&frame).unwrap();
        assert_eq!(eth.ethertype(), EthernetProtocol::Ipv4);
        assert_eq!(eth.dst_addr(), GUEST_MAC);

        let ip = Ipv4Packet::new_checked(eth.payload()).unwrap();
        assert_eq!(Ipv4Addr::from(ip.src_addr()), server);
        assert_eq!(Ipv4Addr::from(ip.dst_addr()), guest);
        assert_eq!(ip.next_header(), IpProtocol::Udp);

        let udp = UdpPacket::new_checked(ip.payload()).unwrap();
        assert_eq!((udp.src_port(), udp.dst_port()), (53, 12345));
        assert_eq!(udp.payload(), payload);
    }

    #[test]
    fn construct_v6_response_has_correct_structure() {
        let server = ipv6("2001:db8::1");
        let guest = ipv6("fd42:6d73:62::2");
        let payload = b"hello ipv6";

        let frame =
            construct_udp_response_v6(server, 53, guest, 12345, payload, GATEWAY_MAC, GUEST_MAC);
        assert_eq!(
            frame.len(),
            ETH_HDR_LEN + IPV6_HDR_LEN + UDP_HDR_LEN + payload.len()
        );

        let eth = EthernetFrame::new_checked(&frame).unwrap();
        assert_eq!(eth.ethertype(), EthernetProtocol::Ipv6);

        let ip = Ipv6Packet::new_checked(eth.payload()).unwrap();
        assert_eq!(ip.next_header(), IpProtocol::Udp);

        let udp = UdpPacket::new_checked(ip.payload()).unwrap();
        assert_eq!((udp.src_port(), udp.dst_port()), (53, 12345));
        assert_eq!(udp.payload(), payload);
        assert_ne!(udp.checksum(), 0, "IPv6 UDP checksum must not be zero");

        let pseudo_src = smoltcp::wire::IpAddress::from(server);
        let pseudo_dst = smoltcp::wire::IpAddress::from(guest);
        assert!(
            udp.verify_checksum(&pseudo_src, &pseudo_dst),
            "IPv6 UDP checksum must be valid"
        );
    }

    #[test]
    fn extract_payload_from_v6_udp_frame() {
        let frame = construct_udp_response_v6(
            ipv6("2001:db8::1"),
            80,
            ipv6("fd42:6d73:62::2"),
            54321,
            b"v6 data",
            ANY_MAC,
            ANY_MAC,
        );
        assert_eq!(extract_udp_payload(&frame), Some(&b"v6 data"[..]));
    }

    #[test]
    fn extract_payload_from_v4_udp_frame() {
        let frame = construct_udp_response_v4(
            Ipv4Addr::new(1, 2, 3, 4),
            80,
            Ipv4Addr::new(10, 0, 0, 2),
            54321,
            b"test data",
            ANY_MAC,
            ANY_MAC,
        );
        assert_eq!(extract_udp_payload(&frame), Some(&b"test data"[..]));
    }
}
518}