1use crate::{
2 packet::{IpHeader, NetworkPacket, TransportHeader},
3 PacketReceiver, PacketSender, TTL,
4};
5use anyhow::Context;
6
7use bytes::Bytes;
8use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header, UdpHeader};
9
10use smol_timeout::TimeoutExt;
11use std::{net::SocketAddr, pin::Pin, time::Duration};
12
13#[derive(Debug)]
14pub struct IpStackUdpStream {
15 src_addr: SocketAddr,
16 dst_addr: SocketAddr,
17 stream_sender: PacketSender,
18 stream_receiver: Pin<Box<PacketReceiver>>,
19 pkt_sender: PacketSender,
20
21 udp_timeout: Duration,
22 mtu: u16,
23}
24
25impl IpStackUdpStream {
26 pub fn new(
27 src_addr: SocketAddr,
28 dst_addr: SocketAddr,
29
30 pkt_sender: PacketSender,
31 mtu: u16,
32 udp_timeout: Duration,
33 ) -> Self {
34 let (stream_sender, stream_receiver) = async_channel::bounded::<NetworkPacket>(10);
35
36 IpStackUdpStream {
37 src_addr,
38 dst_addr,
39 stream_sender,
40 stream_receiver: Box::pin(stream_receiver),
41 pkt_sender,
42
43 udp_timeout,
44 mtu,
45 }
46 }
47
48 pub async fn recv(&self) -> anyhow::Result<Bytes> {
49 Ok(self
50 .stream_receiver
51 .recv()
52 .timeout(self.udp_timeout)
53 .await
54 .context("timeout")??
55 .payload
56 .into())
57 }
58
59 pub async fn send(&self, bts: &[u8]) -> anyhow::Result<()> {
60 let packet = self
61 .create_rev_packet(TTL, bts.to_vec())
62 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
63 self.pkt_sender.send(packet).await?;
64 Ok(())
65 }
66
67 pub(crate) fn stream_sender(&self) -> PacketSender {
68 self.stream_sender.clone()
69 }
70
71 fn create_rev_packet(&self, ttl: u8, payload: Vec<u8>) -> anyhow::Result<NetworkPacket> {
72 const UHS: usize = 8; match (self.dst_addr.ip(), self.src_addr.ip()) {
74 (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
75 let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets())?;
76 if self.mtu < (ip_h.header_len() + UHS) as u16 {
77 anyhow::bail!("message too large");
78 }
79 let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16);
80 if payload.len() > line_buffer as usize {
81 anyhow::bail!("message too large");
82 }
83 ip_h.set_payload_len(payload.len() + UHS)?;
84 let udp_header = UdpHeader::with_ipv4_checksum(
85 self.dst_addr.port(),
86 self.src_addr.port(),
87 &ip_h,
88 &payload,
89 )?;
90 Ok(NetworkPacket {
91 ip: IpHeader::Ipv4(ip_h),
92 transport: TransportHeader::Udp(udp_header),
93 payload,
94 })
95 }
96 (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
97 let mut ip_h = Ipv6Header {
98 traffic_class: 0,
99 flow_label: Ipv6FlowLabel::ZERO,
100 payload_length: 0,
101 next_header: IpNumber::UDP,
102 hop_limit: ttl,
103 source: dst.octets(),
104 destination: src.octets(),
105 };
106 if self.mtu < (ip_h.header_len() + UHS) as u16 {
107 anyhow::bail!("message too large");
108 }
109 let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16);
110
111 if payload.len() > line_buffer as usize {
112 anyhow::bail!("message too large");
113 }
114
115 ip_h.payload_length = (payload.len() + UHS) as u16;
116 let udp_header = UdpHeader::with_ipv6_checksum(
117 self.dst_addr.port(),
118 self.src_addr.port(),
119 &ip_h,
120 &payload,
121 )?;
122 Ok(NetworkPacket {
123 ip: IpHeader::Ipv6(ip_h),
124 transport: TransportHeader::Udp(udp_header),
125 payload,
126 })
127 }
128 _ => unreachable!(),
129 }
130 }
131
132 pub fn local_addr(&self) -> SocketAddr {
133 self.src_addr
134 }
135
136 pub fn peer_addr(&self) -> SocketAddr {
137 self.dst_addr
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        net::{IpAddr, Ipv4Addr, Ipv6Addr},
        time::Duration,
    };

    /// Builds a stream with the given endpoints and MTU; outbound packets are never sent.
    fn make_stream(src: SocketAddr, dst: SocketAddr, mtu: u16) -> IpStackUdpStream {
        let (sender, _receiver) = async_channel::unbounded();
        IpStackUdpStream::new(src, dst, sender, mtu, Duration::from_secs(60))
    }

    fn v4(last: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, last)), port)
    }

    fn v6(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), port)
    }

    #[test]
    fn oversized_udp_datagram_returns_error() {
        // 20-byte IPv4 header + 8-byte UDP header leave no room for payload.
        let stream = make_stream(v4(1, 1000), v4(2, 2000), 28);

        assert!(stream.create_rev_packet(TTL, vec![1]).is_err());
        assert!(stream.create_rev_packet(TTL, Vec::new()).is_ok());
    }

    #[test]
    fn oversized_ipv6_udp_datagram_returns_error() {
        // 40-byte IPv6 header + 8-byte UDP header leave no room for payload.
        let stream = make_stream(v6(1000), v6(2000), 48);

        assert!(stream.create_rev_packet(TTL, vec![1]).is_err());
        assert!(stream.create_rev_packet(TTL, Vec::new()).is_ok());
    }

    #[test]
    fn ipv4_payload_exactly_filling_mtu_is_accepted() {
        // 1500 - 20 (IPv4) - 8 (UDP) = 1472 bytes of payload.
        let stream = make_stream(v4(1, 1000), v4(2, 2000), 1500);

        assert!(stream.create_rev_packet(TTL, vec![0; 1472]).is_ok());
        assert!(stream.create_rev_packet(TTL, vec![0; 1473]).is_err());
    }

    #[test]
    fn ipv6_payload_exactly_filling_mtu_is_accepted() {
        // 1500 - 40 (IPv6) - 8 (UDP) = 1452 bytes of payload.
        let stream = make_stream(v6(1000), v6(2000), 1500);

        assert!(stream.create_rev_packet(TTL, vec![0; 1452]).is_ok());
        assert!(stream.create_rev_packet(TTL, vec![0; 1453]).is_err());
    }

    #[test]
    fn reply_packet_swaps_ports() {
        let stream = make_stream(v4(1, 1000), v4(2, 2000), 1500);
        let packet = stream.create_rev_packet(TTL, vec![1, 2, 3]).unwrap();

        let TransportHeader::Udp(udp) = &packet.transport else {
            panic!("expected a UDP transport header");
        };
        assert_eq!(udp.source_port, 2000);
        assert_eq!(udp.destination_port, 1000);
        assert_eq!(udp.length, 11);
        assert_eq!(packet.payload, vec![1, 2, 3]);
    }
}