ipstack_geph/stream/
udp.rs1use crate::{
2 packet::{IpHeader, NetworkPacket, TransportHeader},
3 PacketReceiver, PacketSender, TTL,
4};
5use anyhow::Context;
6
7use bytes::Bytes;
8use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header, UdpHeader};
9
10use smol_timeout::TimeoutExt;
11use std::{net::SocketAddr, pin::Pin, time::Duration};
12
13#[derive(Debug)]
14pub struct IpStackUdpStream {
15 src_addr: SocketAddr,
16 dst_addr: SocketAddr,
17 stream_sender: PacketSender,
18 stream_receiver: Pin<Box<PacketReceiver>>,
19 pkt_sender: PacketSender,
20
21 udp_timeout: Duration,
22 mtu: u16,
23}
24
25impl IpStackUdpStream {
26 pub fn new(
27 src_addr: SocketAddr,
28 dst_addr: SocketAddr,
29
30 pkt_sender: PacketSender,
31 mtu: u16,
32 udp_timeout: Duration,
33 ) -> Self {
34 let (stream_sender, stream_receiver) = async_channel::bounded::<NetworkPacket>(10);
35
36 IpStackUdpStream {
37 src_addr,
38 dst_addr,
39 stream_sender,
40 stream_receiver: Box::pin(stream_receiver),
41 pkt_sender,
42
43 udp_timeout,
44 mtu,
45 }
46 }
47
48 pub async fn recv(&self) -> anyhow::Result<Bytes> {
49 Ok(self
50 .stream_receiver
51 .recv()
52 .timeout(self.udp_timeout)
53 .await
54 .context("timeout")??
55 .payload
56 .into())
57 }
58
59 pub async fn send(&self, bts: &[u8]) -> anyhow::Result<()> {
60 let packet = self
61 .create_rev_packet(TTL, bts.to_vec())
62 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
63 self.pkt_sender.send(packet).await?;
64 Ok(())
65 }
66
67 pub(crate) fn stream_sender(&self) -> PacketSender {
68 self.stream_sender.clone()
69 }
70
71 fn create_rev_packet(&self, ttl: u8, mut payload: Vec<u8>) -> anyhow::Result<NetworkPacket> {
72 const UHS: usize = 8; match (self.dst_addr.ip(), self.src_addr.ip()) {
74 (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
75 let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets())?;
76 let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16);
77 payload.truncate(line_buffer as usize);
78 ip_h.set_payload_len(payload.len() + UHS)?;
79 let udp_header = UdpHeader::with_ipv4_checksum(
80 self.dst_addr.port(),
81 self.src_addr.port(),
82 &ip_h,
83 &payload,
84 )?;
85 Ok(NetworkPacket {
86 ip: IpHeader::Ipv4(ip_h),
87 transport: TransportHeader::Udp(udp_header),
88 payload,
89 })
90 }
91 (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
92 let mut ip_h = Ipv6Header {
93 traffic_class: 0,
94 flow_label: Ipv6FlowLabel::ZERO,
95 payload_length: 0,
96 next_header: IpNumber::UDP,
97 hop_limit: ttl,
98 source: dst.octets(),
99 destination: src.octets(),
100 };
101 let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16);
102
103 payload.truncate(line_buffer as usize);
104
105 ip_h.payload_length = (payload.len() + UHS) as u16;
106 let udp_header = UdpHeader::with_ipv6_checksum(
107 self.dst_addr.port(),
108 self.src_addr.port(),
109 &ip_h,
110 &payload,
111 )?;
112 Ok(NetworkPacket {
113 ip: IpHeader::Ipv6(ip_h),
114 transport: TransportHeader::Udp(udp_header),
115 payload,
116 })
117 }
118 _ => unreachable!(),
119 }
120 }
121
122 pub fn local_addr(&self) -> SocketAddr {
123 self.src_addr
124 }
125
126 pub fn peer_addr(&self) -> SocketAddr {
127 self.dst_addr
128 }
129}