// ipstack_geph/stream/unknown.rs

1use crate::{
2    packet::{IpHeader, NetworkPacket, TransportHeader},
3    PacketSender, TTL,
4};
5use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header};
6use std::{
7    io::{Error, ErrorKind},
8    mem,
9    net::IpAddr,
10};
11
12pub struct IpStackUnknownTransport {
13    src_addr: IpAddr,
14    dst_addr: IpAddr,
15    payload: Vec<u8>,
16    protocol: IpNumber,
17    mtu: u16,
18    packet_sender: PacketSender,
19}
20
21impl IpStackUnknownTransport {
22    pub(crate) fn new(
23        src_addr: IpAddr,
24        dst_addr: IpAddr,
25        payload: Vec<u8>,
26        ip: &IpHeader,
27        mtu: u16,
28        packet_sender: PacketSender,
29    ) -> Self {
30        let protocol = match ip {
31            IpHeader::Ipv4(ip) => ip.protocol,
32            IpHeader::Ipv6(ip) => ip.next_header,
33        };
34        IpStackUnknownTransport {
35            src_addr,
36            dst_addr,
37            payload,
38            protocol,
39            mtu,
40            packet_sender,
41        }
42    }
43    pub fn src_addr(&self) -> IpAddr {
44        self.src_addr
45    }
46    pub fn dst_addr(&self) -> IpAddr {
47        self.dst_addr
48    }
49    pub fn payload(&self) -> &[u8] {
50        &self.payload
51    }
52    pub fn ip_protocol(&self) -> IpNumber {
53        self.protocol
54    }
55    pub fn send(&self, mut payload: Vec<u8>) -> Result<(), Error> {
56        loop {
57            let packet = self
58                .create_rev_packet(&mut payload)
59                .map_err(|e| std::io::Error::new(ErrorKind::InvalidData, e))?;
60            self.packet_sender
61                .try_send(packet)
62                .map_err(|_| Error::other("send error"))?;
63            if payload.is_empty() {
64                return Ok(());
65            }
66        }
67    }
68
69    pub fn create_rev_packet(&self, payload: &mut Vec<u8>) -> anyhow::Result<NetworkPacket> {
70        match (self.dst_addr, self.src_addr) {
71            (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
72                let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets())?;
73                let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16);
74                if line_buffer == 0 && !payload.is_empty() {
75                    anyhow::bail!("message too large");
76                }
77
78                let p = if payload.len() > line_buffer as usize {
79                    payload.drain(0..line_buffer as usize).collect::<Vec<u8>>()
80                } else {
81                    mem::take(payload)
82                };
83                ip_h.set_payload_len(p.len())?;
84                Ok(NetworkPacket {
85                    ip: IpHeader::Ipv4(ip_h),
86                    transport: TransportHeader::Unknown,
87                    payload: p,
88                })
89            }
90            (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
91                let mut ip_h = Ipv6Header {
92                    traffic_class: 0,
93                    flow_label: Ipv6FlowLabel::ZERO,
94                    payload_length: 0,
95                    next_header: self.protocol,
96                    hop_limit: TTL,
97                    source: dst.octets(),
98                    destination: src.octets(),
99                };
100                let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16);
101                if line_buffer == 0 && !payload.is_empty() {
102                    anyhow::bail!("message too large");
103                }
104                let p = if payload.len() > line_buffer as usize {
105                    payload.drain(0..line_buffer as usize).collect::<Vec<u8>>()
106                } else {
107                    mem::take(payload)
108                };
109                ip_h.payload_length = p.len() as u16;
110                Ok(NetworkPacket {
111                    ip: IpHeader::Ipv6(ip_h),
112                    transport: TransportHeader::Unknown,
113                    payload: p,
114                })
115            }
116            _ => unreachable!(),
117        }
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv6Addr;

    /// Fixture: an ICMPv6 transport between loopback addresses with the given MTU.
    ///
    /// The channel receiver is discarded; these tests never send on the channel.
    fn loopback_icmpv6(mtu: u16) -> IpStackUnknownTransport {
        let (tx, _rx) = async_channel::unbounded();
        IpStackUnknownTransport {
            src_addr: IpAddr::V6(Ipv6Addr::LOCALHOST),
            dst_addr: IpAddr::V6(Ipv6Addr::LOCALHOST),
            payload: Vec::new(),
            protocol: IpNumber::IPV6_ICMP,
            mtu,
            packet_sender: tx,
        }
    }

    #[test]
    fn ipv6_unknown_transport_preserves_protocol_and_splits_payload() {
        // 48-byte MTU minus the 40-byte IPv6 header leaves room for 8 payload bytes.
        let transport = loopback_icmpv6(48);
        let mut pending: Vec<u8> = (1..=9).collect();

        let packet = transport.create_rev_packet(&mut pending).unwrap();

        if let IpHeader::Ipv6(header) = &packet.ip {
            assert_eq!(header.next_header, IpNumber::IPV6_ICMP);
        } else {
            panic!("expected IPv6");
        }
        assert_eq!(packet.payload, (1..=8).collect::<Vec<u8>>());
        assert_eq!(pending, vec![9]);
    }

    #[test]
    fn unknown_transport_rejects_non_empty_payload_when_mtu_cannot_fit_data() {
        // A 40-byte MTU is used up entirely by the IPv6 header.
        let transport = loopback_icmpv6(40);
        let mut pending = vec![1u8];

        let outcome = transport.create_rev_packet(&mut pending);

        assert!(outcome.is_err());
        // The rejected payload must be left exactly as it was.
        assert_eq!(pending, vec![1]);
    }
}