Skip to main content

ipstack_geph/stream/
udp.rs

1use crate::{
2    packet::{IpHeader, NetworkPacket, TransportHeader},
3    PacketReceiver, PacketSender, TTL,
4};
5use anyhow::Context;
6
7use bytes::Bytes;
8use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header, UdpHeader};
9
10use smol_timeout::TimeoutExt;
11use std::{net::SocketAddr, pin::Pin, time::Duration};
12
13#[derive(Debug)]
14pub struct IpStackUdpStream {
15    src_addr: SocketAddr,
16    dst_addr: SocketAddr,
17    stream_sender: PacketSender,
18    stream_receiver: Pin<Box<PacketReceiver>>,
19    pkt_sender: PacketSender,
20
21    udp_timeout: Duration,
22    mtu: u16,
23}
24
25impl IpStackUdpStream {
26    pub fn new(
27        src_addr: SocketAddr,
28        dst_addr: SocketAddr,
29
30        pkt_sender: PacketSender,
31        mtu: u16,
32        udp_timeout: Duration,
33    ) -> Self {
34        let (stream_sender, stream_receiver) = async_channel::bounded::<NetworkPacket>(10);
35
36        IpStackUdpStream {
37            src_addr,
38            dst_addr,
39            stream_sender,
40            stream_receiver: Box::pin(stream_receiver),
41            pkt_sender,
42
43            udp_timeout,
44            mtu,
45        }
46    }
47
48    pub async fn recv(&self) -> anyhow::Result<Bytes> {
49        Ok(self
50            .stream_receiver
51            .recv()
52            .timeout(self.udp_timeout)
53            .await
54            .context("timeout")??
55            .payload
56            .into())
57    }
58
59    pub async fn send(&self, bts: &[u8]) -> anyhow::Result<()> {
60        let packet = self
61            .create_rev_packet(TTL, bts.to_vec())
62            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
63        self.pkt_sender.send(packet).await?;
64        Ok(())
65    }
66
67    pub(crate) fn stream_sender(&self) -> PacketSender {
68        self.stream_sender.clone()
69    }
70
71    fn create_rev_packet(&self, ttl: u8, payload: Vec<u8>) -> anyhow::Result<NetworkPacket> {
72        const UHS: usize = 8; // udp header size is 8
73        match (self.dst_addr.ip(), self.src_addr.ip()) {
74            (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
75                let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets())?;
76                if self.mtu < (ip_h.header_len() + UHS) as u16 {
77                    anyhow::bail!("message too large");
78                }
79                let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16);
80                if payload.len() > line_buffer as usize {
81                    anyhow::bail!("message too large");
82                }
83                ip_h.set_payload_len(payload.len() + UHS)?;
84                let udp_header = UdpHeader::with_ipv4_checksum(
85                    self.dst_addr.port(),
86                    self.src_addr.port(),
87                    &ip_h,
88                    &payload,
89                )?;
90                Ok(NetworkPacket {
91                    ip: IpHeader::Ipv4(ip_h),
92                    transport: TransportHeader::Udp(udp_header),
93                    payload,
94                })
95            }
96            (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
97                let mut ip_h = Ipv6Header {
98                    traffic_class: 0,
99                    flow_label: Ipv6FlowLabel::ZERO,
100                    payload_length: 0,
101                    next_header: IpNumber::UDP,
102                    hop_limit: ttl,
103                    source: dst.octets(),
104                    destination: src.octets(),
105                };
106                if self.mtu < (ip_h.header_len() + UHS) as u16 {
107                    anyhow::bail!("message too large");
108                }
109                let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16);
110
111                if payload.len() > line_buffer as usize {
112                    anyhow::bail!("message too large");
113                }
114
115                ip_h.payload_length = (payload.len() + UHS) as u16;
116                let udp_header = UdpHeader::with_ipv6_checksum(
117                    self.dst_addr.port(),
118                    self.src_addr.port(),
119                    &ip_h,
120                    &payload,
121                )?;
122                Ok(NetworkPacket {
123                    ip: IpHeader::Ipv6(ip_h),
124                    transport: TransportHeader::Udp(udp_header),
125                    payload,
126                })
127            }
128            _ => unreachable!(),
129        }
130    }
131
132    pub fn local_addr(&self) -> SocketAddr {
133        self.src_addr
134    }
135
136    pub fn peer_addr(&self) -> SocketAddr {
137        self.dst_addr
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        net::{IpAddr, Ipv4Addr, Ipv6Addr},
        time::Duration,
    };

    /// Builds a stream from `local` to `peer` with the given MTU. The outbound
    /// receiver is dropped right away: these tests only call
    /// `create_rev_packet`, which never touches the outbound channel.
    fn stream_with_mtu(local: SocketAddr, peer: SocketAddr, mtu: u16) -> IpStackUdpStream {
        let (sender, _receiver) = async_channel::unbounded();
        IpStackUdpStream::new(local, peer, sender, mtu, Duration::from_secs(60))
    }

    /// `10.0.0.<last>:<port>`
    fn v4(last: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, last)), port)
    }

    /// `[fd00::<last>]:<port>`
    fn v6(last: u16, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, last)), port)
    }

    #[test]
    fn oversized_udp_datagram_returns_error() {
        // 20-byte IPv4 header + 8-byte UDP header: an MTU of 28 leaves no payload room.
        let stream = stream_with_mtu(v4(1, 1000), v4(2, 2000), 28);

        assert!(stream.create_rev_packet(TTL, vec![1]).is_err());
        assert!(stream.create_rev_packet(TTL, Vec::new()).is_ok());
    }

    #[test]
    fn oversized_ipv6_udp_datagram_returns_error() {
        // 40-byte IPv6 header + 8-byte UDP header: an MTU of 48 leaves no payload room.
        let localhost = IpAddr::V6(Ipv6Addr::LOCALHOST);
        let stream = stream_with_mtu(
            SocketAddr::new(localhost, 1000),
            SocketAddr::new(localhost, 2000),
            48,
        );

        assert!(stream.create_rev_packet(TTL, vec![1]).is_err());
        assert!(stream.create_rev_packet(TTL, Vec::new()).is_ok());
    }

    #[test]
    fn mtu_below_header_overhead_rejects_even_empty_payload() {
        let v4_stream = stream_with_mtu(v4(1, 1000), v4(2, 2000), 27);
        assert!(v4_stream.create_rev_packet(TTL, Vec::new()).is_err());

        let v6_stream = stream_with_mtu(v6(1, 1000), v6(2, 2000), 47);
        assert!(v6_stream.create_rev_packet(TTL, Vec::new()).is_err());
    }

    #[test]
    fn ipv4_reply_swaps_endpoints_and_sets_lengths() {
        let stream = stream_with_mtu(v4(1, 1000), v4(2, 2000), 1500);
        let packet = stream.create_rev_packet(TTL, vec![1, 2, 3]).unwrap();

        match &packet.ip {
            IpHeader::Ipv4(h) => {
                assert_eq!(h.source, [10, 0, 0, 2]);
                assert_eq!(h.destination, [10, 0, 0, 1]);
                assert_eq!(h.time_to_live, TTL);
            }
            _ => panic!("expected an IPv4 header"),
        }
        match &packet.transport {
            TransportHeader::Udp(u) => {
                assert_eq!(u.source_port, 2000);
                assert_eq!(u.destination_port, 1000);
                assert_eq!(u.length, 8 + 3);
            }
            _ => panic!("expected a UDP header"),
        }
        assert_eq!(packet.payload, vec![1, 2, 3]);
    }

    #[test]
    fn ipv6_reply_swaps_endpoints_and_sets_lengths() {
        let stream = stream_with_mtu(v6(1, 1000), v6(2, 2000), 1500);
        let packet = stream.create_rev_packet(TTL, vec![9; 4]).unwrap();

        match &packet.ip {
            IpHeader::Ipv6(h) => {
                assert_eq!(h.source, Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 2).octets());
                assert_eq!(h.destination, Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1).octets());
                assert_eq!(h.hop_limit, TTL);
                assert_eq!(h.payload_length, 8 + 4);
            }
            _ => panic!("expected an IPv6 header"),
        }
        match &packet.transport {
            TransportHeader::Udp(u) => {
                assert_eq!(u.source_port, 2000);
                assert_eq!(u.destination_port, 1000);
            }
            _ => panic!("expected a UDP header"),
        }
        assert_eq!(packet.payload, vec![9; 4]);
    }
}