// ipstack_geph/stream/udp.rs

1use crate::{
2    packet::{IpHeader, NetworkPacket, TransportHeader},
3    PacketReceiver, PacketSender, TTL,
4};
5use anyhow::Context;
6
7use bytes::Bytes;
8use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header, UdpHeader};
9
10use smol_timeout::TimeoutExt;
11use std::{net::SocketAddr, pin::Pin, time::Duration};
12
13#[derive(Debug)]
14pub struct IpStackUdpStream {
15    src_addr: SocketAddr,
16    dst_addr: SocketAddr,
17    stream_sender: PacketSender,
18    stream_receiver: Pin<Box<PacketReceiver>>,
19    pkt_sender: PacketSender,
20
21    udp_timeout: Duration,
22    mtu: u16,
23}
24
25impl IpStackUdpStream {
26    pub fn new(
27        src_addr: SocketAddr,
28        dst_addr: SocketAddr,
29
30        pkt_sender: PacketSender,
31        mtu: u16,
32        udp_timeout: Duration,
33    ) -> Self {
34        let (stream_sender, stream_receiver) = async_channel::bounded::<NetworkPacket>(10);
35
36        IpStackUdpStream {
37            src_addr,
38            dst_addr,
39            stream_sender,
40            stream_receiver: Box::pin(stream_receiver),
41            pkt_sender,
42
43            udp_timeout,
44            mtu,
45        }
46    }
47
48    pub async fn recv(&self) -> anyhow::Result<Bytes> {
49        Ok(self
50            .stream_receiver
51            .recv()
52            .timeout(self.udp_timeout)
53            .await
54            .context("timeout")??
55            .payload
56            .into())
57    }
58
59    pub async fn send(&self, bts: &[u8]) -> anyhow::Result<()> {
60        let packet = self
61            .create_rev_packet(TTL, bts.to_vec())
62            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
63        self.pkt_sender.send(packet).await?;
64        Ok(())
65    }
66
67    pub(crate) fn stream_sender(&self) -> PacketSender {
68        self.stream_sender.clone()
69    }
70
71    fn create_rev_packet(&self, ttl: u8, mut payload: Vec<u8>) -> anyhow::Result<NetworkPacket> {
72        const UHS: usize = 8; // udp header size is 8
73        match (self.dst_addr.ip(), self.src_addr.ip()) {
74            (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
75                let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets())?;
76                let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16);
77                payload.truncate(line_buffer as usize);
78                ip_h.set_payload_len(payload.len() + UHS)?;
79                let udp_header = UdpHeader::with_ipv4_checksum(
80                    self.dst_addr.port(),
81                    self.src_addr.port(),
82                    &ip_h,
83                    &payload,
84                )?;
85                Ok(NetworkPacket {
86                    ip: IpHeader::Ipv4(ip_h),
87                    transport: TransportHeader::Udp(udp_header),
88                    payload,
89                })
90            }
91            (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
92                let mut ip_h = Ipv6Header {
93                    traffic_class: 0,
94                    flow_label: Ipv6FlowLabel::ZERO,
95                    payload_length: 0,
96                    next_header: IpNumber::UDP,
97                    hop_limit: ttl,
98                    source: dst.octets(),
99                    destination: src.octets(),
100                };
101                let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16);
102
103                payload.truncate(line_buffer as usize);
104
105                ip_h.payload_length = (payload.len() + UHS) as u16;
106                let udp_header = UdpHeader::with_ipv6_checksum(
107                    self.dst_addr.port(),
108                    self.src_addr.port(),
109                    &ip_h,
110                    &payload,
111                )?;
112                Ok(NetworkPacket {
113                    ip: IpHeader::Ipv6(ip_h),
114                    transport: TransportHeader::Udp(udp_header),
115                    payload,
116                })
117            }
118            _ => unreachable!(),
119        }
120    }
121
122    pub fn local_addr(&self) -> SocketAddr {
123        self.src_addr
124    }
125
126    pub fn peer_addr(&self) -> SocketAddr {
127        self.dst_addr
128    }
129}