1use crate::ip::IpNextProtocol;
4use nex_core::bitfield::u16be;
5
6use core::convert::TryInto;
7use core::u16;
8use core::u8;
9use std::net::{Ipv4Addr, Ipv6Addr};
10
/// Conversion of a value into its individual bytes.
///
/// The implementations in this module (for `u8`, `u16`, `u32` and `u64`) emit the bytes
/// most significant first, i.e. in network byte order.
pub trait Octets {
    /// Returns the bytes making up `self`.
    fn octets(&self) -> Self::Output;

    /// Container holding the produced bytes (e.g. `[u8; 4]` for `u32`).
    type Output;
}
19
20impl Octets for u64 {
21 type Output = [u8; 8];
22
23 fn octets(&self) -> Self::Output {
24 [
25 (*self >> 56) as u8,
26 (*self >> 48) as u8,
27 (*self >> 40) as u8,
28 (*self >> 32) as u8,
29 (*self >> 24) as u8,
30 (*self >> 16) as u8,
31 (*self >> 8) as u8,
32 *self as u8,
33 ]
34 }
35}
36
37impl Octets for u32 {
38 type Output = [u8; 4];
39
40 fn octets(&self) -> Self::Output {
41 [
42 (*self >> 24) as u8,
43 (*self >> 16) as u8,
44 (*self >> 8) as u8,
45 *self as u8,
46 ]
47 }
48}
49
50impl Octets for u16 {
51 type Output = [u8; 2];
52
53 fn octets(&self) -> Self::Output {
54 [(*self >> 8) as u8, *self as u8]
55 }
56}
57
58impl Octets for u8 {
59 type Output = [u8; 1];
60
61 fn octets(&self) -> Self::Output {
62 [*self]
63 }
64}
65
66pub fn checksum(data: &[u8], skipword: usize) -> u16be {
69 if data.len() == 0 {
70 return 0;
71 }
72 let sum = sum_be_words(data, skipword);
73 finalize_checksum(sum)
74}
75
76fn finalize_checksum(mut sum: u32) -> u16be {
77 while sum >> 16 != 0 {
78 sum = (sum >> 16) + (sum & 0xFFFF);
79 }
80 !sum as u16
81}
82
83pub fn ipv4_checksum(
85 data: &[u8],
86 skipword: usize,
87 extra_data: &[u8],
88 source: &Ipv4Addr,
89 destination: &Ipv4Addr,
90 next_level_protocol: IpNextProtocol,
91) -> u16be {
92 let mut sum = 0u32;
93
94 sum += ipv4_word_sum(source);
96 sum += ipv4_word_sum(destination);
97 sum += next_level_protocol as u32;
98
99 let len = data.len() + extra_data.len();
100 sum += len as u32;
101
102 sum += sum_be_words(data, skipword);
104 sum += sum_be_words(extra_data, extra_data.len() / 2);
105
106 finalize_checksum(sum)
107}
108
/// Sums the two big-endian 16-bit words of an IPv4 address (no carry folding).
fn ipv4_word_sum(ip: &Ipv4Addr) -> u32 {
    // The address as one big-endian integer; its high and low halves are the two words.
    let bits = u32::from(*ip);
    (bits >> 16) + (bits & 0xFFFF)
}
113
114pub fn ipv6_checksum(
116 data: &[u8],
117 skipword: usize,
118 extra_data: &[u8],
119 source: &Ipv6Addr,
120 destination: &Ipv6Addr,
121 next_level_protocol: IpNextProtocol,
122) -> u16be {
123 let mut sum = 0u32;
124
125 sum += ipv6_word_sum(source);
127 sum += ipv6_word_sum(destination);
128 sum += next_level_protocol as u32;
129
130 let len = data.len() + extra_data.len();
131 sum += len as u32;
132
133 sum += sum_be_words(data, skipword);
135 sum += sum_be_words(extra_data, extra_data.len() / 2);
136
137 finalize_checksum(sum)
138}
139
/// Sums the eight 16-bit segments of an IPv6 address (no carry folding).
fn ipv6_word_sum(ip: &Ipv6Addr) -> u32 {
    ip.segments()
        .iter()
        .fold(0u32, |total, &segment| total + u32::from(segment))
}
143
/// Adds up `data` as big-endian 16-bit words, without complementing.
///
/// An odd trailing byte is treated as word number `data.len() / 2`, padded with a zero
/// low byte. The word whose index equals `skipword` is left out; an index past the end
/// (e.g. `usize::MAX`) skips nothing.
///
/// Sums that fit in a `u32` are returned exactly. Larger sums (inputs above roughly
/// 128 KiB) are reduced with end-around carry. This preserves their ones'-complement
/// value instead of overflowing.
fn sum_be_words(data: &[u8], skipword: usize) -> u32 {
    let mut words = data.chunks_exact(2);

    // Each word is at most 0xFFFF, so a u64 accumulator cannot overflow for any
    // realistic slice length.
    let mut sum: u64 = words
        .by_ref()
        .enumerate()
        .filter(|&(index, _)| index != skipword)
        .map(|(_, word)| u64::from(u16::from_be_bytes([word[0], word[1]])))
        .sum();

    // Odd-length input: the leftover byte is the high byte of word `data.len() / 2`.
    if let &[last] = words.remainder() {
        if data.len() / 2 != skipword {
            sum += u64::from(last) << 8;
        }
    }

    // Folding leaves the value unchanged modulo 0xFFFF, which is all the final
    // checksum depends on.
    while sum > u64::from(u32::MAX) {
        sum = (sum >> 16) + (sum & 0xFFFF);
    }
    sum as u32
}
170
#[cfg(test)]
mod tests {
    use super::sum_be_words;

    /// 0x00..=0x0A is 11 bytes, so the last byte forms a zero-padded odd word; skip
    /// indices past the end leave every word in the sum.
    #[test]
    fn sum_be_words_different_skipwords() {
        let data = (0..11).collect::<Vec<u8>>();
        assert_eq!(7190, sum_be_words(&data, 1));
        assert_eq!(6676, sum_be_words(&data, 2));
        assert_eq!(7705, sum_be_words(&data, 99));
        assert_eq!(7705, sum_be_words(&data, 101));
    }

    #[test]
    fn sum_be_words_small_sizes() {
        let data_zero = vec![0; 0];
        assert_eq!(0, sum_be_words(&data_zero, 0));
        assert_eq!(0, sum_be_words(&data_zero, 10));
        let data_one = vec![1; 1];
        // The lone odd byte is word 0, so skipping word 0 leaves nothing to sum.
        assert_eq!(0, sum_be_words(&data_one, 0));
        assert_eq!(256, sum_be_words(&data_one, 1));
        let data_two = vec![1; 2];
        assert_eq!(0, sum_be_words(&data_two, 0));
        assert_eq!(257, sum_be_words(&data_two, 1));
        let data_three = vec![4; 3];
        assert_eq!(1024, sum_be_words(&data_three, 0));
        assert_eq!(1028, sum_be_words(&data_three, 1));
        assert_eq!(2052, sum_be_words(&data_three, 2));
        assert_eq!(2052, sum_be_words(&data_three, 3));
    }

    /// Runs the same sums over a slice that starts at an odd address, so the 16-bit
    /// reads are not 2-byte aligned.
    #[test]
    fn sum_be_words_misaligned_ptr() {
        let mut data = vec![0u8; 13];
        let offset = if data.as_ptr() as usize % 2 == 0 { 1 } else { 0 };
        let slice_data = &mut data[offset..offset + 12];
        for (i, byte) in slice_data.iter_mut().take(11).enumerate() {
            *byte = i as u8;
        }
        let slice_data: &[u8] = slice_data;
        assert_eq!(7190, sum_be_words(slice_data, 1));
        assert_eq!(6676, sum_be_words(slice_data, 2));
        assert_eq!(7705, sum_be_words(slice_data, 99));
        assert_eq!(7705, sum_be_words(slice_data, 101));
    }
}