1use crate::{
2 packet::{IpHeader, NetworkPacket, TransportHeader},
3 PacketSender, TTL,
4};
5use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header};
6use std::{
7 io::{Error, ErrorKind},
8 mem,
9 net::IpAddr,
10};
11
12pub struct IpStackUnknownTransport {
13 src_addr: IpAddr,
14 dst_addr: IpAddr,
15 payload: Vec<u8>,
16 protocol: IpNumber,
17 mtu: u16,
18 packet_sender: PacketSender,
19}
20
21impl IpStackUnknownTransport {
22 pub(crate) fn new(
23 src_addr: IpAddr,
24 dst_addr: IpAddr,
25 payload: Vec<u8>,
26 ip: &IpHeader,
27 mtu: u16,
28 packet_sender: PacketSender,
29 ) -> Self {
30 let protocol = match ip {
31 IpHeader::Ipv4(ip) => ip.protocol,
32 IpHeader::Ipv6(ip) => ip.next_header,
33 };
34 IpStackUnknownTransport {
35 src_addr,
36 dst_addr,
37 payload,
38 protocol,
39 mtu,
40 packet_sender,
41 }
42 }
43 pub fn src_addr(&self) -> IpAddr {
44 self.src_addr
45 }
46 pub fn dst_addr(&self) -> IpAddr {
47 self.dst_addr
48 }
49 pub fn payload(&self) -> &[u8] {
50 &self.payload
51 }
52 pub fn ip_protocol(&self) -> IpNumber {
53 self.protocol
54 }
55 pub fn send(&self, mut payload: Vec<u8>) -> Result<(), Error> {
56 loop {
57 let packet = self
58 .create_rev_packet(&mut payload)
59 .map_err(|e| std::io::Error::new(ErrorKind::InvalidData, e))?;
60 self.packet_sender
61 .try_send(packet)
62 .map_err(|_| Error::other("send error"))?;
63 if payload.is_empty() {
64 return Ok(());
65 }
66 }
67 }
68
69 pub fn create_rev_packet(&self, payload: &mut Vec<u8>) -> anyhow::Result<NetworkPacket> {
70 match (self.dst_addr, self.src_addr) {
71 (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
72 let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets())?;
73 let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16);
74 if line_buffer == 0 && !payload.is_empty() {
75 anyhow::bail!("message too large");
76 }
77
78 let p = if payload.len() > line_buffer as usize {
79 payload.drain(0..line_buffer as usize).collect::<Vec<u8>>()
80 } else {
81 mem::take(payload)
82 };
83 ip_h.set_payload_len(p.len())?;
84 Ok(NetworkPacket {
85 ip: IpHeader::Ipv4(ip_h),
86 transport: TransportHeader::Unknown,
87 payload: p,
88 })
89 }
90 (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
91 let mut ip_h = Ipv6Header {
92 traffic_class: 0,
93 flow_label: Ipv6FlowLabel::ZERO,
94 payload_length: 0,
95 next_header: self.protocol,
96 hop_limit: TTL,
97 source: dst.octets(),
98 destination: src.octets(),
99 };
100 let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16);
101 if line_buffer == 0 && !payload.is_empty() {
102 anyhow::bail!("message too large");
103 }
104 let p = if payload.len() > line_buffer as usize {
105 payload.drain(0..line_buffer as usize).collect::<Vec<u8>>()
106 } else {
107 mem::take(payload)
108 };
109 ip_h.payload_length = p.len() as u16;
110 Ok(NetworkPacket {
111 ip: IpHeader::Ipv6(ip_h),
112 transport: TransportHeader::Unknown,
113 payload: p,
114 })
115 }
116 _ => unreachable!(),
117 }
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Builds a transport backed by an unbounded channel and returns the receiving end,
    /// so tests can inspect what `send` queued.
    fn make_transport(
        src_addr: IpAddr,
        dst_addr: IpAddr,
        protocol: IpNumber,
        mtu: u16,
    ) -> (IpStackUnknownTransport, async_channel::Receiver<NetworkPacket>) {
        let (packet_sender, packet_receiver) = async_channel::unbounded();
        let transport = IpStackUnknownTransport {
            src_addr,
            dst_addr,
            payload: Vec::new(),
            protocol,
            mtu,
            packet_sender,
        };
        (transport, packet_receiver)
    }

    fn localhost6() -> IpAddr {
        IpAddr::V6(Ipv6Addr::LOCALHOST)
    }

    #[test]
    fn ipv6_unknown_transport_preserves_protocol_and_splits_payload() {
        // 40-byte IPv6 header + 8 bytes of room for data.
        let (transport, _rx) = make_transport(localhost6(), localhost6(), IpNumber::IPV6_ICMP, 48);
        let mut payload = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];

        let packet = transport.create_rev_packet(&mut payload).unwrap();
        let IpHeader::Ipv6(ip) = packet.ip else {
            panic!("expected IPv6");
        };
        assert_eq!(ip.next_header, IpNumber::IPV6_ICMP);
        assert_eq!(ip.payload_length, 8);
        assert_eq!(packet.payload, vec![1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(payload, vec![9]);
    }

    #[test]
    fn ipv6_reply_swaps_source_and_destination() {
        let src = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);
        let (transport, _rx) =
            make_transport(IpAddr::V6(src), IpAddr::V6(dst), IpNumber::IPV6_ICMP, 1500);
        let mut payload = vec![42];

        let packet = transport.create_rev_packet(&mut payload).unwrap();
        let IpHeader::Ipv6(ip) = packet.ip else {
            panic!("expected IPv6");
        };
        assert_eq!(ip.source, dst.octets());
        assert_eq!(ip.destination, src.octets());
        assert_eq!(packet.payload, vec![42]);
        assert!(payload.is_empty());
    }

    #[test]
    fn ipv4_unknown_transport_preserves_protocol_and_splits_payload() {
        // 20-byte IPv4 header (no options) + 8 bytes of room for data.
        let addr = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let (transport, _rx) = make_transport(addr, addr, IpNumber::ICMP, 28);
        let mut payload = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];

        let packet = transport.create_rev_packet(&mut payload).unwrap();
        let IpHeader::Ipv4(ip) = packet.ip else {
            panic!("expected IPv4");
        };
        assert_eq!(ip.protocol, IpNumber::ICMP);
        assert_eq!(packet.payload, vec![1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(payload, vec![9]);
    }

    #[test]
    fn unknown_transport_rejects_non_empty_payload_when_mtu_cannot_fit_data() {
        let (transport, _rx) = make_transport(localhost6(), localhost6(), IpNumber::IPV6_ICMP, 40);
        let mut payload = vec![1];

        assert!(transport.create_rev_packet(&mut payload).is_err());
        assert_eq!(payload, vec![1]);
    }

    #[test]
    fn empty_payload_yields_header_only_packet_even_without_room() {
        let (transport, _rx) = make_transport(localhost6(), localhost6(), IpNumber::IPV6_ICMP, 40);
        let mut payload = Vec::new();

        let packet = transport.create_rev_packet(&mut payload).unwrap();
        assert!(packet.payload.is_empty());
        assert!(payload.is_empty());
    }

    #[test]
    fn send_queues_one_packet_per_mtu_sized_chunk() {
        let (transport, rx) = make_transport(localhost6(), localhost6(), IpNumber::IPV6_ICMP, 48);

        transport.send(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]).unwrap();

        assert_eq!(rx.try_recv().unwrap().payload, vec![1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(rx.try_recv().unwrap().payload, vec![9]);
        assert!(rx.try_recv().is_err());
    }
}