1use crate::ip::IpNextProtocol;
4use nex_core::bitfield::u16be;
5
6use core::u8;
7use core::u16;
8use std::net::{Ipv4Addr, Ipv6Addr};
9
/// Splits an integer into its individual bytes.
///
/// Every implementation in this module emits the most significant byte first
/// (big-endian / network byte order).
pub trait Octets {
    /// Fixed-size byte array returned by [`Octets::octets`]; its length equals the
    /// size of the implementing type in bytes.
    type Output;

    /// Returns the bytes of `self`, most significant byte first.
    fn octets(&self) -> Self::Output;
}
18
19impl Octets for u64 {
20 type Output = [u8; 8];
21
22 fn octets(&self) -> Self::Output {
23 [
24 (*self >> 56) as u8,
25 (*self >> 48) as u8,
26 (*self >> 40) as u8,
27 (*self >> 32) as u8,
28 (*self >> 24) as u8,
29 (*self >> 16) as u8,
30 (*self >> 8) as u8,
31 *self as u8,
32 ]
33 }
34}
35
36impl Octets for u32 {
37 type Output = [u8; 4];
38
39 fn octets(&self) -> Self::Output {
40 [
41 (*self >> 24) as u8,
42 (*self >> 16) as u8,
43 (*self >> 8) as u8,
44 *self as u8,
45 ]
46 }
47}
48
49impl Octets for u16 {
50 type Output = [u8; 2];
51
52 fn octets(&self) -> Self::Output {
53 [(*self >> 8) as u8, *self as u8]
54 }
55}
56
57impl Octets for u8 {
58 type Output = [u8; 1];
59
60 fn octets(&self) -> Self::Output {
61 [*self]
62 }
63}
64
65pub fn checksum(data: &[u8], skipword: usize) -> u16be {
68 if data.len() == 0 {
69 return 0;
70 }
71 let sum = sum_be_words(data, skipword);
72 finalize_checksum(sum)
73}
74
75fn finalize_checksum(mut sum: u32) -> u16be {
76 while sum >> 16 != 0 {
77 sum = (sum >> 16) + (sum & 0xFFFF);
78 }
79 !sum as u16
80}
81
82pub fn ipv4_checksum(
84 data: &[u8],
85 skipword: usize,
86 extra_data: &[u8],
87 source: &Ipv4Addr,
88 destination: &Ipv4Addr,
89 next_level_protocol: IpNextProtocol,
90) -> u16be {
91 let mut sum = 0u32;
92
93 sum += ipv4_word_sum(source);
95 sum += ipv4_word_sum(destination);
96 sum += next_level_protocol as u32;
97
98 let len = data.len() + extra_data.len();
99 sum += len as u32;
100
101 sum += sum_be_words(data, skipword);
103 sum += sum_be_words(extra_data, extra_data.len() / 2);
104
105 finalize_checksum(sum)
106}
107
/// Adds the two big-endian 16-bit halves of an IPv4 address.
fn ipv4_word_sum(ip: &Ipv4Addr) -> u32 {
    // `u32::from` packs the address big-endian (a.b.c.d -> 0xaabbccdd).
    let packed = u32::from(*ip);
    let high_half = packed >> 16;
    let low_half = packed & 0xFFFF;
    high_half + low_half
}
112
113pub fn ipv6_checksum(
115 data: &[u8],
116 skipword: usize,
117 extra_data: &[u8],
118 source: &Ipv6Addr,
119 destination: &Ipv6Addr,
120 next_level_protocol: IpNextProtocol,
121) -> u16be {
122 let mut sum = 0u32;
123
124 sum += ipv6_word_sum(source);
126 sum += ipv6_word_sum(destination);
127 sum += next_level_protocol as u32;
128
129 let len = data.len() + extra_data.len();
130 sum += len as u32;
131
132 sum += sum_be_words(data, skipword);
134 sum += sum_be_words(extra_data, extra_data.len() / 2);
135
136 finalize_checksum(sum)
137}
138
/// Adds the eight 16-bit segments of an IPv6 address.
fn ipv6_word_sum(ip: &Ipv6Addr) -> u32 {
    ip.segments()
        .iter()
        .fold(0u32, |total, &segment| total + u32::from(segment))
}
142
/// Sums `data` as big-endian 16-bit words using one's-complement (end-around carry)
/// arithmetic.
///
/// The word at index `skipword` (bytes `2 * skipword` and `2 * skipword + 1`) is
/// left out. When `data` has odd length, its last byte is treated as the high byte
/// of a zero-padded word with index `data.len() / 2`. That word is also left out if
/// its index equals `skipword`. Pass `usize::MAX` to skip nothing.
///
/// Carries are folded back into the low 16 bits, so the result is always `<= 0xFFFF`.
/// It equals the plain sum whenever that sum fits in 16 bits. Otherwise it has the
/// same one's-complement value, so `finalize_checksum` yields the same checksum.
/// Callers can also add several results into a `u32` without overflowing.
fn sum_be_words(data: &[u8], skipword: usize) -> u32 {
    let words = data.chunks_exact(2);
    // An odd-length input leaves one byte over; it occupies word index `len / 2`.
    let leftover = words.remainder();
    let leftover_index = data.len() / 2;

    // Accumulate in u64. A u32 accumulator overflows after ~65 537 words of 0xFFFF
    // (about 128 KiB), which panics in debug builds and wraps in release builds.
    let mut sum: u64 = words
        .enumerate()
        .filter(|&(index, _)| index != skipword)
        .map(|(_, pair)| (u64::from(pair[0]) << 8) | u64::from(pair[1]))
        .sum();

    if let [last] = leftover {
        if leftover_index != skipword {
            sum += u64::from(*last) << 8;
        }
    }

    // End-around carry: fold the bits above bit 15 back into the low 16 bits.
    while sum > 0xFFFF {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }
    sum as u32
}
168
#[cfg(test)]
mod tests {
    use super::sum_be_words;

    #[test]
    fn sum_be_words_different_skipwords() {
        let data = (0..11).collect::<Vec<u8>>();
        assert_eq!(7190, sum_be_words(&data, 1));
        assert_eq!(6676, sum_be_words(&data, 2));
        assert_eq!(7705, sum_be_words(&data, 99));
        // Any skipword past the last (padded) word skips nothing.
        assert_eq!(7705, sum_be_words(&data, 101));
    }

    #[test]
    fn sum_be_words_small_sizes() {
        let data_zero = vec![0; 0];
        assert_eq!(0, sum_be_words(&data_zero, 0));
        assert_eq!(0, sum_be_words(&data_zero, 10));
        let data_one = vec![1; 1];
        // The lone byte is padded word 0, so skipping word 0 drops it entirely.
        assert_eq!(0, sum_be_words(&data_one, 0));
        assert_eq!(256, sum_be_words(&data_one, 1));
        let data_two = vec![1; 2];
        assert_eq!(0, sum_be_words(&data_two, 0));
        assert_eq!(257, sum_be_words(&data_two, 1));
        let data_three = vec![4; 3];
        assert_eq!(1024, sum_be_words(&data_three, 0));
        assert_eq!(1028, sum_be_words(&data_three, 1));
        assert_eq!(2052, sum_be_words(&data_three, 2));
        assert_eq!(2052, sum_be_words(&data_three, 3));
    }

    #[test]
    fn sum_be_words_misaligned_ptr() {
        let mut data = vec![0u8; 13];
        // View 12 bytes starting at an odd address, so the words are not 2-byte aligned.
        let offset = if data.as_ptr() as usize % 2 == 0 { 1 } else { 0 };
        let slice_data = &mut data[offset..offset + 12];
        for (i, byte) in slice_data.iter_mut().take(11).enumerate() {
            *byte = i as u8;
        }
        assert_eq!(7190, sum_be_words(slice_data, 1));
        assert_eq!(6676, sum_be_words(slice_data, 2));
        assert_eq!(7705, sum_be_words(slice_data, 99));
        assert_eq!(7705, sum_be_words(slice_data, 101));
    }
}