Skip to main content

nex_packet/
util.rs

//! Utilities for working with packets, e.g. checksumming.
2
3use crate::ip::IpNextProtocol;
4use nex_core::bitfield::u16be;
5
6use core::u8;
7use core::u16;
8use std::net::{Ipv4Addr, Ipv6Addr};
9
/// Conversion of a value into its raw byte representation.
pub trait Octets {
    /// The byte array produced by [`Octets::octets`].
    type Output;

    /// Returns the bytes of `self`, most significant byte first (big-endian / network order).
    fn octets(&self) -> Self::Output;
}
18
19impl Octets for u64 {
20    type Output = [u8; 8];
21
22    fn octets(&self) -> Self::Output {
23        [
24            (*self >> 56) as u8,
25            (*self >> 48) as u8,
26            (*self >> 40) as u8,
27            (*self >> 32) as u8,
28            (*self >> 24) as u8,
29            (*self >> 16) as u8,
30            (*self >> 8) as u8,
31            *self as u8,
32        ]
33    }
34}
35
36impl Octets for u32 {
37    type Output = [u8; 4];
38
39    fn octets(&self) -> Self::Output {
40        [
41            (*self >> 24) as u8,
42            (*self >> 16) as u8,
43            (*self >> 8) as u8,
44            *self as u8,
45        ]
46    }
47}
48
49impl Octets for u16 {
50    type Output = [u8; 2];
51
52    fn octets(&self) -> Self::Output {
53        [(*self >> 8) as u8, *self as u8]
54    }
55}
56
57impl Octets for u8 {
58    type Output = [u8; 1];
59
60    fn octets(&self) -> Self::Output {
61        [*self]
62    }
63}
64
65/// Calculates a checksum. Used by ipv4 and icmp. The two bytes starting at `skipword * 2` will be
66/// ignored. Supposed to be the checksum field, which is regarded as zero during calculation.
67pub fn checksum(data: &[u8], skipword: usize) -> u16be {
68    if data.len() == 0 {
69        return 0;
70    }
71    let sum = sum_be_words(data, skipword);
72    finalize_checksum(sum)
73}
74
75fn finalize_checksum(mut sum: u32) -> u16be {
76    while sum >> 16 != 0 {
77        sum = (sum >> 16) + (sum & 0xFFFF);
78    }
79    !sum as u16
80}
81
82/// Calculate the checksum for a packet built on IPv4. Used by UDP and TCP.
83pub fn ipv4_checksum(
84    data: &[u8],
85    skipword: usize,
86    extra_data: &[u8],
87    source: &Ipv4Addr,
88    destination: &Ipv4Addr,
89    next_level_protocol: IpNextProtocol,
90) -> u16be {
91    let mut sum = 0u32;
92
93    // Checksum pseudo-header
94    sum += ipv4_word_sum(source);
95    sum += ipv4_word_sum(destination);
96    sum += next_level_protocol as u32;
97
98    let len = data.len() + extra_data.len();
99    sum += len as u32;
100
101    // Checksum packet header and data
102    sum += sum_be_words(data, skipword);
103    sum += sum_be_words(extra_data, extra_data.len() / 2);
104
105    finalize_checksum(sum)
106}
107
/// Adds the two 16-bit big-endian words that make up an IPv4 address.
fn ipv4_word_sum(ip: &Ipv4Addr) -> u32 {
    let [a, b, c, d] = ip.octets();
    let high_word = u32::from(u16::from_be_bytes([a, b]));
    let low_word = u32::from(u16::from_be_bytes([c, d]));
    high_word + low_word
}
112
113/// Calculate the checksum for a packet built on IPv6.
114pub fn ipv6_checksum(
115    data: &[u8],
116    skipword: usize,
117    extra_data: &[u8],
118    source: &Ipv6Addr,
119    destination: &Ipv6Addr,
120    next_level_protocol: IpNextProtocol,
121) -> u16be {
122    let mut sum = 0u32;
123
124    // Checksum pseudo-header
125    sum += ipv6_word_sum(source);
126    sum += ipv6_word_sum(destination);
127    sum += next_level_protocol as u32;
128
129    let len = data.len() + extra_data.len();
130    sum += len as u32;
131
132    // Checksum packet header and data
133    sum += sum_be_words(data, skipword);
134    sum += sum_be_words(extra_data, extra_data.len() / 2);
135
136    finalize_checksum(sum)
137}
138
/// Adds the eight 16-bit segments that make up an IPv6 address.
fn ipv6_word_sum(ip: &Ipv6Addr) -> u32 {
    ip.segments()
        .iter()
        .fold(0u32, |total, &segment| total + u32::from(segment))
}
142
/// Sums the 16-bit big-endian words in `data`, skipping the word at word index `skipword`.
///
/// Word `i` is made of bytes `2 * i` and `2 * i + 1`. If `data` has odd length, the final byte
/// forms word `data.len() / 2` on its own as its high byte, with a zero low byte. That word is
/// skipped when `skipword` equals its index. An index past the end (for example `usize::MAX`)
/// skips nothing.
///
/// The words are accumulated in a `u64`, and the total is folded with end-around carry into
/// 16 bits before it is returned. This means arbitrarily long input cannot overflow.
/// Folding keeps the value congruent modulo `0xFFFF` and never turns a non-zero sum into zero,
/// so a one's-complement checksum computed from the result is unchanged.
fn sum_be_words(data: &[u8], skipword: usize) -> u32 {
    let words = data.chunks_exact(2);
    // The index an odd trailing byte occupies equals the number of whole words.
    let tail_index = words.len();
    let tail = words.remainder();

    let mut sum: u64 = words
        .enumerate()
        .filter(|&(i, _)| i != skipword)
        .map(|(_, w)| u64::from(u16::from_be_bytes([w[0], w[1]])))
        .sum();

    // If the length is odd, make sure to checksum the final byte (as a high byte).
    if let [last] = tail {
        if tail_index != skipword {
            sum += u64::from(*last) << 8;
        }
    }

    while sum >> 16 != 0 {
        sum = (sum >> 16) + (sum & 0xFFFF);
    }
    // After folding, `sum <= 0xFFFF`, so this cast is lossless.
    sum as u32
}
168
#[cfg(test)]
mod tests {
    use super::sum_be_words;

    #[test]
    fn sum_be_words_different_skipwords() {
        let data = (0..11).collect::<Vec<u8>>();
        assert_eq!(7190, sum_be_words(&data, 1));
        assert_eq!(6676, sum_be_words(&data, 2));
        // Assert having the skipword outside the range gives correct and equal
        // results
        assert_eq!(7705, sum_be_words(&data, 99));
        assert_eq!(7705, sum_be_words(&data, 101));
    }

    #[test]
    fn sum_be_words_small_sizes() {
        let data_zero = vec![0; 0];
        assert_eq!(0, sum_be_words(&data_zero, 0));
        assert_eq!(0, sum_be_words(&data_zero, 10));
        let data_one = vec![1; 1];
        // The lone byte is word 0, so skipping word 0 drops it entirely.
        assert_eq!(0, sum_be_words(&data_one, 0));
        assert_eq!(256, sum_be_words(&data_one, 1));
        let data_two = vec![1; 2];
        assert_eq!(0, sum_be_words(&data_two, 0));
        assert_eq!(257, sum_be_words(&data_two, 1));
        let data_three = vec![4; 3];
        assert_eq!(1024, sum_be_words(&data_three, 0));
        assert_eq!(1028, sum_be_words(&data_three, 1));
        assert_eq!(2052, sum_be_words(&data_three, 2));
        assert_eq!(2052, sum_be_words(&data_three, 3));
    }

    #[test]
    fn sum_be_words_misaligned_ptr() {
        // Take a 12-byte view that starts at an odd address, so the 16-bit words are not
        // 2-byte aligned. A safe subslice is enough; no raw-pointer arithmetic is needed.
        let mut data = vec![0u8; 13];
        let start = if data.as_ptr() as usize % 2 == 0 { 1 } else { 0 };
        let slice_data = &mut data[start..start + 12];
        for (byte, value) in slice_data.iter_mut().zip(0u8..11) {
            *byte = value;
        }
        let slice_data: &[u8] = slice_data;

        assert_eq!(7190, sum_be_words(slice_data, 1));
        assert_eq!(6676, sum_be_words(slice_data, 2));
        // Assert having the skipword outside the range gives correct and equal
        // results
        assert_eq!(7705, sum_be_words(slice_data, 99));
        assert_eq!(7705, sum_be_words(slice_data, 101));
    }
}