Skip to main content

ombrac_netstack/
udp.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use etherparse::PacketBuilder;
5use tokio::sync::mpsc;
6
7use crate::buffer::BufferPool;
8use crate::stack::IpPacket;
9use crate::stack::{NetStackConfig, Packet};
10use crate::{error, trace};
11
12pub struct UdpPacket {
13    pub data: Packet,
14    pub src_addr: SocketAddr,
15    pub dst_addr: SocketAddr,
16}
17
18impl<T> From<(T, SocketAddr, SocketAddr)> for UdpPacket
19where
20    T: Into<Packet>,
21{
22    fn from((data, src_addr, dst_addr): (T, SocketAddr, SocketAddr)) -> Self {
23        UdpPacket {
24            data: data.into(),
25            src_addr,
26            dst_addr,
27        }
28    }
29}
30
31impl UdpPacket {
32    pub fn data(&self) -> &[u8] {
33        self.data.data()
34    }
35}
36
37pub struct UdpTunnel {
38    inbound: mpsc::Receiver<Packet>,
39    outbound: mpsc::Sender<Packet>,
40    buffer_pool: Arc<BufferPool>,
41    config: Arc<NetStackConfig>,
42}
43
44impl UdpTunnel {
45    pub fn new(
46        config: Arc<NetStackConfig>,
47        inbound: mpsc::Receiver<Packet>,
48        outbound: mpsc::Sender<Packet>,
49        buffer_pool: Arc<BufferPool>,
50    ) -> Self {
51        Self {
52            inbound,
53            outbound,
54            buffer_pool,
55            config,
56        }
57    }
58
59    pub fn split(self) -> (SplitRead, SplitWrite) {
60        let read = SplitRead { recv: self.inbound };
61        let write = SplitWrite {
62            config: self.config,
63            send: self.outbound,
64            buffer_pool: self.buffer_pool,
65        };
66        (read, write)
67    }
68}
69
70pub struct SplitRead {
71    recv: mpsc::Receiver<Packet>,
72}
73
74impl SplitRead {
75    pub async fn recv(&mut self) -> Option<UdpPacket> {
76        self.recv.recv().await.and_then(|data| {
77            let original_bytes = data.into_bytes();
78
79            let packet = match IpPacket::new_checked(&original_bytes) {
80                Ok(p) => p,
81                Err(_e) => {
82                    error!("invalid IP packet: {_e}");
83                    return None;
84                }
85            };
86
87            let src_ip = packet.src_addr();
88            let dst_ip = packet.dst_addr();
89            let ip_payload = packet.payload();
90
91            let udp_packet = match smoltcp::wire::UdpPacket::new_checked(ip_payload) {
92                Ok(p) => p,
93                Err(_e) => {
94                    error!(
95                        "invalid err: {_e}, src_ip: {src_ip}, dst_ip: {dst_ip}, \
96                         payload: {ip_payload:?}"
97                    );
98                    return None;
99                }
100            };
101            let src_port = udp_packet.src_port();
102            let dst_port = udp_packet.dst_port();
103
104            let udp_payload_slice = udp_packet.payload();
105
106            let original_ptr = original_bytes.as_ptr() as usize;
107            let payload_ptr = udp_payload_slice.as_ptr() as usize;
108
109            let offset = payload_ptr - original_ptr;
110            let len = udp_payload_slice.len();
111
112            let payload_bytes = original_bytes.slice(offset..offset + len);
113
114            let src_addr = SocketAddr::new(src_ip, src_port);
115            let dst_addr = SocketAddr::new(dst_ip, dst_port);
116
117            trace!("created UDP socket for {src_addr} <-> {dst_addr}");
118
119            Some(UdpPacket {
120                data: Packet::new(payload_bytes),
121                src_addr,
122                dst_addr,
123            })
124        })
125    }
126}
127
128#[derive(Clone)]
129pub struct SplitWrite {
130    config: Arc<NetStackConfig>,
131    send: mpsc::Sender<Packet>,
132    buffer_pool: Arc<BufferPool>,
133}
134
135impl SplitWrite {
136    pub async fn send(&mut self, packet: UdpPacket) -> Result<(), std::io::Error> {
137        let ttl = self.config.ip_ttl;
138        let builder = match (packet.src_addr, packet.dst_addr) {
139            (SocketAddr::V4(src), SocketAddr::V4(dst)) => {
140                PacketBuilder::ipv4(src.ip().octets(), dst.ip().octets(), ttl)
141                    .udp(src.port(), dst.port())
142            }
143            (SocketAddr::V6(src), SocketAddr::V6(dst)) => {
144                PacketBuilder::ipv6(src.ip().octets(), dst.ip().octets(), ttl)
145                    .udp(src.port(), dst.port())
146            }
147            _ => {
148                return Err(std::io::Error::new(
149                    std::io::ErrorKind::InvalidInput,
150                    "UDP socket only supports IPv4 and IPv6",
151                ));
152            }
153        };
154
155        let mut buffer = self.buffer_pool.get(builder.size(packet.data.data().len()));
156        builder
157            .write(&mut buffer, packet.data.data())
158            .map_err(std::io::Error::other)?;
159        let final_bytes = buffer.split().freeze();
160
161        match self.send.send(Packet::new(final_bytes)).await {
162            Ok(()) => Ok(()),
163            Err(err) => Err(std::io::Error::other(format!("send error: {err}"))),
164        }
165    }
166}