1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use etherparse::PacketBuilder;
5use tokio::sync::mpsc;
6
7use crate::buffer::BufferPool;
8use crate::stack::IpPacket;
9use crate::stack::{NetStackConfig, Packet};
10use crate::{error, trace};
11
12pub struct UdpPacket {
13 pub data: Packet,
14 pub src_addr: SocketAddr,
15 pub dst_addr: SocketAddr,
16}
17
18impl<T> From<(T, SocketAddr, SocketAddr)> for UdpPacket
19where
20 T: Into<Packet>,
21{
22 fn from((data, src_addr, dst_addr): (T, SocketAddr, SocketAddr)) -> Self {
23 UdpPacket {
24 data: data.into(),
25 src_addr,
26 dst_addr,
27 }
28 }
29}
30
31impl UdpPacket {
32 pub fn data(&self) -> &[u8] {
33 self.data.data()
34 }
35}
36
37pub struct UdpTunnel {
38 inbound: mpsc::Receiver<Packet>,
39 outbound: mpsc::Sender<Packet>,
40 buffer_pool: Arc<BufferPool>,
41 config: Arc<NetStackConfig>,
42}
43
44impl UdpTunnel {
45 pub fn new(
46 config: Arc<NetStackConfig>,
47 inbound: mpsc::Receiver<Packet>,
48 outbound: mpsc::Sender<Packet>,
49 buffer_pool: Arc<BufferPool>,
50 ) -> Self {
51 Self {
52 inbound,
53 outbound,
54 buffer_pool,
55 config,
56 }
57 }
58
59 pub fn split(self) -> (SplitRead, SplitWrite) {
60 let read = SplitRead { recv: self.inbound };
61 let write = SplitWrite {
62 config: self.config,
63 send: self.outbound,
64 buffer_pool: self.buffer_pool,
65 };
66 (read, write)
67 }
68}
69
70pub struct SplitRead {
71 recv: mpsc::Receiver<Packet>,
72}
73
74impl SplitRead {
75 pub async fn recv(&mut self) -> Option<UdpPacket> {
76 self.recv.recv().await.and_then(|data| {
77 let original_bytes = data.into_bytes();
78
79 let packet = match IpPacket::new_checked(&original_bytes) {
80 Ok(p) => p,
81 Err(_e) => {
82 error!("invalid IP packet: {_e}");
83 return None;
84 }
85 };
86
87 let src_ip = packet.src_addr();
88 let dst_ip = packet.dst_addr();
89 let ip_payload = packet.payload();
90
91 let udp_packet = match smoltcp::wire::UdpPacket::new_checked(ip_payload) {
92 Ok(p) => p,
93 Err(_e) => {
94 error!(
95 "invalid err: {_e}, src_ip: {src_ip}, dst_ip: {dst_ip}, \
96 payload: {ip_payload:?}"
97 );
98 return None;
99 }
100 };
101 let src_port = udp_packet.src_port();
102 let dst_port = udp_packet.dst_port();
103
104 let udp_payload_slice = udp_packet.payload();
105
106 let original_ptr = original_bytes.as_ptr() as usize;
107 let payload_ptr = udp_payload_slice.as_ptr() as usize;
108
109 let offset = payload_ptr - original_ptr;
110 let len = udp_payload_slice.len();
111
112 let payload_bytes = original_bytes.slice(offset..offset + len);
113
114 let src_addr = SocketAddr::new(src_ip, src_port);
115 let dst_addr = SocketAddr::new(dst_ip, dst_port);
116
117 trace!("created UDP socket for {src_addr} <-> {dst_addr}");
118
119 Some(UdpPacket {
120 data: Packet::new(payload_bytes),
121 src_addr,
122 dst_addr,
123 })
124 })
125 }
126}
127
128#[derive(Clone)]
129pub struct SplitWrite {
130 config: Arc<NetStackConfig>,
131 send: mpsc::Sender<Packet>,
132 buffer_pool: Arc<BufferPool>,
133}
134
135impl SplitWrite {
136 pub async fn send(&mut self, packet: UdpPacket) -> Result<(), std::io::Error> {
137 let ttl = self.config.ip_ttl;
138 let builder = match (packet.src_addr, packet.dst_addr) {
139 (SocketAddr::V4(src), SocketAddr::V4(dst)) => {
140 PacketBuilder::ipv4(src.ip().octets(), dst.ip().octets(), ttl)
141 .udp(src.port(), dst.port())
142 }
143 (SocketAddr::V6(src), SocketAddr::V6(dst)) => {
144 PacketBuilder::ipv6(src.ip().octets(), dst.ip().octets(), ttl)
145 .udp(src.port(), dst.port())
146 }
147 _ => {
148 return Err(std::io::Error::new(
149 std::io::ErrorKind::InvalidInput,
150 "UDP socket only supports IPv4 and IPv6",
151 ));
152 }
153 };
154
155 let mut buffer = self.buffer_pool.get(builder.size(packet.data.data().len()));
156 builder
157 .write(&mut buffer, packet.data.data())
158 .map_err(std::io::Error::other)?;
159 let final_bytes = buffer.split().freeze();
160
161 match self.send.send(Packet::new(final_bytes)).await {
162 Ok(()) => Ok(()),
163 Err(err) => Err(std::io::Error::other(format!("send error: {err}"))),
164 }
165 }
166}