1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use crate::AsBeBytes;
use super::{Header, TransportHeader, PacketData, Protocol, ParseError, PseudoHeader};
/// In-memory representation of a UDP header plus the pseudo-header data
/// needed to compute its checksum.
///
/// NOTE(review): the `#[get]` / `#[set]` markers are presumably consumed by the
/// project-local `AddGetter` / `AddSetter` derives to generate accessors for
/// the marked fields — confirm against the derive crate.
#[derive(AddGetter, AddSetter)]
pub struct UdpHeader {
// Source port; `parse` reads it from bytes 0-1 (big-endian).
#[get] #[set] src_port: u16,
// Destination port; `parse` reads it from bytes 2-3 (big-endian).
#[get] #[set] dst_port: u16,
// Value of the UDP length field; `new` starts it at 8 and
// `set_pseudo_header` updates it to account for the payload.
#[get] #[set] length: u16,
// Checksum field as read by `parse` (bytes 6-7). `make` computes the
// checksum it emits itself and does not read this field.
#[get] #[set] checksum: u16,
// Pseudo-header values recorded by `set_pseudo_header`; `make` panics if
// this is still `None`.
pseudo_header: Option<PseudoHeader>,
// Set to `true` by `set_pseudo_header`; not read anywhere in this file's
// visible code.
pseudo_header_set: bool
}
impl UdpHeader {
pub fn new(src_port: u16, dst_port: u16) -> Self {
UdpHeader {
src_port: src_port,
dst_port: dst_port,
length: 8,
checksum: 0,
pseudo_header: None,
pseudo_header_set: false,
}
}
}
/// [`TransportHeader`] support: records the IPv4 pseudo-header values that
/// `make` needs for the checksum.
impl TransportHeader for UdpHeader {
    /// Records the pseudo header for a datagram carrying `data_len` payload
    /// bytes and sets the UDP length field to match.
    ///
    /// * `src_ip` / `dst_ip` - IPv4 addresses as four octets each.
    /// * `data_len` - payload size in bytes, excluding the 8-byte UDP header.
    ///
    /// The length is assigned rather than accumulated, so calling this more
    /// than once (or on a header whose length was already populated) leaves
    /// the length field consistent with the pseudo header.
    ///
    /// # Panics
    ///
    /// Panics with "too much data" when `data_len + 8` does not fit in a
    /// `u16`.
    fn set_pseudo_header(&mut self, src_ip: [u8; 4], dst_ip: [u8; 4], data_len: u16) {
        // 8 is the fixed size of the UDP header itself.
        let total_len = match data_len.checked_add(8) {
            Some(len) => len,
            None => panic!("too much data"),
        };

        self.length = total_len;
        self.pseudo_header = Some(PseudoHeader {
            src_ip,
            dst_ip,
            // 17 is the IP protocol number assigned to UDP.
            protocol: 17,
            data_len: total_len,
        });
        self.pseudo_header_set = true;
    }
}
/// [`Header`] implementation: serialising, parsing and size metadata for the
/// fixed 8-byte UDP header.
impl Header for UdpHeader {
    /// Serialises the header into its 8-byte wire form with the checksum
    /// filled in.
    ///
    /// The checksum sums the pseudo header (source IP, destination IP,
    /// protocol, UDP length) and the UDP header words (source port,
    /// destination port, length; the checksum word itself counts as zero).
    ///
    /// NOTE: this method only has access to the header, so payload bytes are
    /// not included in the sum; the value is a complete UDP checksum only for
    /// datagrams with an empty payload.
    ///
    /// # Panics
    ///
    /// Panics if `set_pseudo_header` has not been called first.
    fn make(self) -> PacketData {
        let src_p = self.src_port.split_to_bytes();
        let dst_p = self.dst_port.split_to_bytes();
        let length_bytes = self.length.split_to_bytes();

        // Header words that take part in the checksum, read before the
        // pseudo header is moved out of `self`.
        let header_words =
            self.src_port as u32 + self.dst_port as u32 + self.length as u32;

        let pseudo_header = match self.pseudo_header {
            Some(pseudo) => pseudo,
            None => panic!("Please set the pseudo header data before calculating the checksum"),
        };

        let mut sum = 0u32;
        sum += ip_sum(pseudo_header.src_ip);
        sum += ip_sum(pseudo_header.dst_ip);
        sum += pseudo_header.protocol as u32;
        sum += pseudo_header.data_len as u32;
        sum += header_words;

        let mut checksum = finalize_checksum(sum);
        if checksum == 0 {
            // On the wire a zero checksum means "no checksum computed", so a
            // computed value of zero is sent as all ones instead.
            checksum = 0xFFFF;
        }
        let checksum_bytes = checksum.split_to_bytes();

        vec![
            src_p[0],
            src_p[1],
            dst_p[0],
            dst_p[1],
            length_bytes[0],
            length_bytes[1],
            checksum_bytes[0],
            checksum_bytes[1],
        ]
    }

    /// Parses the first 8 bytes of `raw_data` as a UDP header.
    ///
    /// All four fields are read as big-endian 16-bit words. The returned
    /// header has no pseudo header set.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::InvalidLength` when fewer than
    /// `get_min_length()` bytes are supplied.
    fn parse(raw_data: &[u8]) -> Result<Box<Self>, ParseError> {
        if raw_data.len() < usize::from(Self::get_min_length()) {
            return Err(ParseError::InvalidLength);
        }

        // Reads the big-endian 16-bit word starting at byte `offset`.
        let word_at = |offset: usize| u16::from_be_bytes([raw_data[offset], raw_data[offset + 1]]);

        Ok(Box::new(Self {
            src_port: word_at(0),
            dst_port: word_at(2),
            length: word_at(4),
            checksum: word_at(6),
            pseudo_header: None,
            pseudo_header_set: false,
        }))
    }

    /// Identifies this header as UDP.
    fn get_proto(&self) -> Protocol {
        Protocol::UDP
    }

    /// Size of this header in bytes; always 8.
    fn get_length(&self) -> u8 {
        8
    }

    /// Minimum number of bytes `parse` requires; always 8.
    fn get_min_length() -> u8 {
        8
    }

    /// Exposes this header through the `TransportHeader` interface.
    fn into_transport_header(&mut self) -> Option<&mut dyn TransportHeader> {
        Some(self)
    }
}
/// Adds the two big-endian 16-bit words that make up a 4-octet address.
///
/// The result is returned as a `u32` so the caller can keep accumulating
/// carries before folding them.
#[inline(always)]
fn ip_sum(octets: [u8; 4]) -> u32 {
    let [a, b, c, d] = octets;
    let high_word = u16::from_be_bytes([a, b]);
    let low_word = u16::from_be_bytes([c, d]);
    u32::from(high_word) + u32::from(low_word)
}
/// Folds a 32-bit running sum into 16 bits and returns its bitwise complement.
///
/// Any bits above the low 16 are repeatedly added back into the low 16 bits
/// until none remain; the inverted result is the checksum value.
#[inline]
fn finalize_checksum(cs: u32) -> u16 {
    let mut folded = cs;
    loop {
        let carry = folded >> 16;
        if carry == 0 {
            break;
        }
        folded = carry + (folded & 0xFFFF);
    }
    // `folded` now fits in 16 bits, so the narrowing cast loses nothing.
    !(folded as u16)
}