interproto/protocols/
udp.rs1use std::net::{Ipv4Addr, Ipv6Addr};
2
3use pnet::packet::udp::{self, MutableUdpPacket, UdpPacket};
4
5use crate::error::{Error, Result};
6
7#[profiling::function]
9pub fn recalculate_udp_checksum_ipv6(
10 udp_packet: &[u8],
11 new_source: Ipv6Addr,
12 new_destination: Ipv6Addr,
13) -> Result<Vec<u8>> {
14 {
16 let mut udp_packet_buffer = udp_packet.to_vec();
18
19 let mut udp_packet =
21 MutableUdpPacket::new(&mut udp_packet_buffer).ok_or(Error::PacketTooShort {
22 expected: UdpPacket::minimum_packet_size(),
23 actual: udp_packet.len(),
24 })?;
25
26 udp_packet.set_checksum(0);
28 udp_packet.set_checksum(udp::ipv6_checksum(
29 &udp_packet.to_immutable(),
30 &new_source,
31 &new_destination,
32 ));
33
34 #[cfg(feature = "metrics")]
36 protomask_metrics::metric!(PACKET_COUNTER, PROTOCOL_UDP, STATUS_TRANSLATED).inc();
37
38 Ok(udp_packet_buffer)
40 }
41 .map_err(|error| {
42 #[cfg(feature = "metrics")]
44 protomask_metrics::metric!(PACKET_COUNTER, PROTOCOL_UDP, STATUS_DROPPED).inc();
45
46 error
48 })
49}
50
51#[profiling::function]
53pub fn recalculate_udp_checksum_ipv4(
54 udp_packet: &[u8],
55 new_source: Ipv4Addr,
56 new_destination: Ipv4Addr,
57) -> Result<Vec<u8>> {
58 {
60 let mut udp_packet_buffer = udp_packet.to_vec();
62
63 let mut udp_packet =
65 MutableUdpPacket::new(&mut udp_packet_buffer).ok_or(Error::PacketTooShort {
66 expected: UdpPacket::minimum_packet_size(),
67 actual: udp_packet.len(),
68 })?;
69
70 udp_packet.set_checksum(0);
72 udp_packet.set_checksum(udp::ipv4_checksum(
73 &udp_packet.to_immutable(),
74 &new_source,
75 &new_destination,
76 ));
77
78 #[cfg(feature = "metrics")]
80 protomask_metrics::metric!(PACKET_COUNTER, PROTOCOL_UDP, STATUS_TRANSLATED).inc();
81
82 Ok(udp_packet_buffer)
84 }
85 .map_err(|error| {
86 #[cfg(feature = "metrics")]
88 protomask_metrics::metric!(PACKET_COUNTER, PROTOCOL_UDP, STATUS_DROPPED).inc();
89
90 error
92 })
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Payload carried by the datagram that both tests feed to the code under test.
    const PAYLOAD: &[u8] = b"Hello, world!";

    /// Builds the shared input datagram: source port 1234, destination port 5678,
    /// length field 13, checksum zero, followed by [`PAYLOAD`].
    ///
    /// NOTE(review): the length field holds only the payload size (13) rather than
    /// header + payload; the checksums asserted by the tests correspond to this
    /// exact input, so changing it means updating the expected values too.
    fn sample_datagram() -> Vec<u8> {
        let mut bytes = vec![0u8; UdpPacket::minimum_packet_size() + 13];
        {
            // Scoped so the mutable view is released before the buffer is returned.
            let mut datagram = MutableUdpPacket::new(&mut bytes).unwrap();
            datagram.set_source(1234);
            datagram.set_destination(5678);
            datagram.set_length(13);
            datagram.set_payload(PAYLOAD);
        }
        bytes
    }

    /// The IPv6 variant writes the expected checksum for 2001:db8::1 -> 2001:db8::2.
    #[test]
    fn test_recalculate_udp_checksum_ipv6() {
        let source: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let destination: Ipv6Addr = "2001:db8::2".parse().unwrap();

        let output =
            recalculate_udp_checksum_ipv6(&sample_datagram(), source, destination).unwrap();

        let checksum = UdpPacket::new(&output).unwrap().get_checksum();
        assert_eq!(checksum, 0x480b);
    }

    /// The IPv4 variant writes the expected checksum for 192.0.2.1 -> 192.0.2.2.
    #[test]
    fn test_recalculate_udp_checksum_ipv4() {
        let source: Ipv4Addr = "192.0.2.1".parse().unwrap();
        let destination: Ipv4Addr = "192.0.2.2".parse().unwrap();

        let output =
            recalculate_udp_checksum_ipv4(&sample_datagram(), source, destination).unwrap();

        let checksum = UdpPacket::new(&output).unwrap().get_checksum();
        assert_eq!(checksum, 0x1f7c);
    }
}