interproto/protocols/udp.rs

1use std::net::{Ipv4Addr, Ipv6Addr};
2
3use pnet::packet::udp::{self, MutableUdpPacket, UdpPacket};
4
5use crate::error::{Error, Result};
6
7/// Re-calculates a UDP packet's checksum with a new IPv6 pseudo-header.
8#[profiling::function]
9pub fn recalculate_udp_checksum_ipv6(
10    udp_packet: &[u8],
11    new_source: Ipv6Addr,
12    new_destination: Ipv6Addr,
13) -> Result<Vec<u8>> {
14    // This scope is used to collect packet drop metrics
15    {
16        // Clone the packet so we can modify it
17        let mut udp_packet_buffer = udp_packet.to_vec();
18
19        // Get safe mutable access to the packet
20        let mut udp_packet =
21            MutableUdpPacket::new(&mut udp_packet_buffer).ok_or(Error::PacketTooShort {
22                expected: UdpPacket::minimum_packet_size(),
23                actual: udp_packet.len(),
24            })?;
25
26        // Edit the packet's checksum
27        udp_packet.set_checksum(0);
28        udp_packet.set_checksum(udp::ipv6_checksum(
29            &udp_packet.to_immutable(),
30            &new_source,
31            &new_destination,
32        ));
33
34        // Track the translated packet
35        #[cfg(feature = "metrics")]
36        protomask_metrics::metric!(PACKET_COUNTER, PROTOCOL_UDP, STATUS_TRANSLATED).inc();
37
38        // Return the translated packet
39        Ok(udp_packet_buffer)
40    }
41    .map_err(|error| {
42        // Track the dropped packet
43        #[cfg(feature = "metrics")]
44        protomask_metrics::metric!(PACKET_COUNTER, PROTOCOL_UDP, STATUS_DROPPED).inc();
45
46        // Pass the error through
47        error
48    })
49}
50
51/// Re-calculates a UDP packet's checksum with a new IPv4 pseudo-header.
52#[profiling::function]
53pub fn recalculate_udp_checksum_ipv4(
54    udp_packet: &[u8],
55    new_source: Ipv4Addr,
56    new_destination: Ipv4Addr,
57) -> Result<Vec<u8>> {
58    // This scope is used to collect packet drop metrics
59    {
60        // Clone the packet so we can modify it
61        let mut udp_packet_buffer = udp_packet.to_vec();
62
63        // Get safe mutable access to the packet
64        let mut udp_packet =
65            MutableUdpPacket::new(&mut udp_packet_buffer).ok_or(Error::PacketTooShort {
66                expected: UdpPacket::minimum_packet_size(),
67                actual: udp_packet.len(),
68            })?;
69
70        // Edit the packet's checksum
71        udp_packet.set_checksum(0);
72        udp_packet.set_checksum(udp::ipv4_checksum(
73            &udp_packet.to_immutable(),
74            &new_source,
75            &new_destination,
76        ));
77
78        // Track the translated packet
79        #[cfg(feature = "metrics")]
80        protomask_metrics::metric!(PACKET_COUNTER, PROTOCOL_UDP, STATUS_TRANSLATED).inc();
81
82        // Return the translated packet
83        Ok(udp_packet_buffer)
84    }
85    .map_err(|error| {
86        // Track the dropped packet
87        #[cfg(feature = "metrics")]
88        protomask_metrics::metric!(PACKET_COUNTER, PROTOCOL_UDP, STATUS_DROPPED).inc();
89
90        // Pass the error through
91        error
92    })
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a UDP packet with the given ports and payload, a length field
    /// covering header + payload, and a zeroed checksum.
    fn build_udp_packet(source: u16, destination: u16, payload: &[u8]) -> Vec<u8> {
        let total_length = UdpPacket::minimum_packet_size() + payload.len();
        let mut buffer = vec![0u8; total_length];
        let mut packet = MutableUdpPacket::new(&mut buffer).unwrap();
        packet.set_source(source);
        packet.set_destination(destination);
        // Test payloads are tiny, so this cast cannot truncate
        packet.set_length(total_length as u16);
        packet.set_payload(payload);
        buffer
    }

    #[test]
    fn test_recalculate_udp_checksum_ipv6() {
        let input_buffer = build_udp_packet(1234, 5678, b"Hello, world!");

        // Recalculate the checksum
        let recalculated_buffer = recalculate_udp_checksum_ipv6(
            &input_buffer,
            "2001:db8::1".parse().unwrap(),
            "2001:db8::2".parse().unwrap(),
        )
        .unwrap();

        // Check that the checksum is correct
        let recalculated_packet = UdpPacket::new(&recalculated_buffer).unwrap();
        assert_eq!(recalculated_packet.get_checksum(), 0x4803);

        // Everything other than the checksum field (bytes 6..8) is untouched
        assert_eq!(&recalculated_buffer[..6], &input_buffer[..6]);
        assert_eq!(&recalculated_buffer[8..], &input_buffer[8..]);
    }

    #[test]
    fn test_recalculate_udp_checksum_ipv4() {
        let input_buffer = build_udp_packet(1234, 5678, b"Hello, world!");

        // Recalculate the checksum
        let recalculated_buffer = recalculate_udp_checksum_ipv4(
            &input_buffer,
            "192.0.2.1".parse().unwrap(),
            "192.0.2.2".parse().unwrap(),
        )
        .unwrap();

        // Check that the checksum is correct
        let recalculated_packet = UdpPacket::new(&recalculated_buffer).unwrap();
        assert_eq!(recalculated_packet.get_checksum(), 0x1f74);

        // Everything other than the checksum field (bytes 6..8) is untouched
        assert_eq!(&recalculated_buffer[..6], &input_buffer[..6]);
        assert_eq!(&recalculated_buffer[8..], &input_buffer[8..]);
    }

    #[test]
    fn test_recalculate_udp_checksum_rejects_short_packet() {
        let short_packet = [0u8; 4];

        assert!(matches!(
            recalculate_udp_checksum_ipv6(
                &short_packet,
                "2001:db8::1".parse().unwrap(),
                "2001:db8::2".parse().unwrap(),
            ),
            Err(Error::PacketTooShort { actual: 4, .. })
        ));
        assert!(matches!(
            recalculate_udp_checksum_ipv4(
                &short_packet,
                "192.0.2.1".parse().unwrap(),
                "192.0.2.2".parse().unwrap(),
            ),
            Err(Error::PacketTooShort { actual: 4, .. })
        ));
    }
}