1use crate::Error;
2use crate::layer3::InternetProtocolId;
3
4use arrayref::array_ref;
5use byteorder::{BigEndian as BE, WriteBytesExt};
6use log::*;
7use nom::*;
8use std::mem::size_of;
9use std::net::IpAddr;
10use std::io::{Cursor, Write};
11
/// Number of bytes in an IPv4 address (read by `ipv4_address` and turned into an `Ipv4Addr`).
const ADDRESS_LENGTH: usize = 4;
13
/// A parsed IPv4 datagram; its variable-length parts (`payload`, `options`, `padding`)
/// borrow from the buffer it was parsed from.
#[derive(Clone, Copy, Debug)]
pub struct IPv4<'a> {
    /// First header byte: the IP version in the high nibble and the header length (IHL),
    /// counted in 32-bit words, in the low nibble.
    pub version_and_length: u8,
    /// Type-of-service byte, stored as read.
    pub tos: u8,
    /// Total-length field as read from the wire; it covers the header plus the payload.
    pub raw_length: u16,
    /// Identification field.
    pub id: u16,
    /// The 16-bit word following `id`, stored unsplit (presumably the flag bits together
    /// with the fragment offset — NOTE(review): confirm before splitting it).
    pub flags: u16,
    /// Time-to-live byte.
    pub ttl: u8,
    /// Protocol of the encapsulated payload, as mapped by `InternetProtocolId::new`.
    pub protocol: InternetProtocolId,
    /// Header checksum as read; neither parsing nor `as_bytes` verifies or recomputes it.
    pub checksum: u16,
    /// Source address; `IPv4::parse` always produces an `IpAddr::V4` here.
    pub src_ip: IpAddr,
    /// Destination address; `IPv4::parse` always produces an `IpAddr::V4` here.
    pub dst_ip: IpAddr,
    /// Payload bytes: `raw_length` minus the header length.
    pub payload: &'a [u8],
    /// Extra header bytes present when the IHL exceeds 5 words; `None` otherwise.
    pub options: Option<&'a [u8]>,
    /// Input bytes beyond what the length fields account for (presumably link-layer
    /// padding); `None` when there are none.
    pub padding: Option<&'a [u8]>,
}
30
31fn to_ip_address(i: &[u8]) -> IpAddr {
32 let ipv4 = std::net::Ipv4Addr::from(array_ref![i, 0, ADDRESS_LENGTH].clone());
33 IpAddr::V4(ipv4)
34}
35
// nom parser: consumes exactly `ADDRESS_LENGTH` (4) bytes and converts them into an
// `IpAddr::V4` via `to_ip_address`. Fails (nom error/incomplete) if fewer bytes remain.
// A plain `//` comment is used because doc comments on macro invocations are ignored.
named!(
    ipv4_address<&[u8], std::net::IpAddr>,
    map!(take!(ADDRESS_LENGTH), to_ip_address)
);
40
41impl<'a> IPv4<'a> {
42 pub fn as_bytes(&self) -> Vec<u8> {
43 let inner = Vec::with_capacity(
44 size_of::<u8>() * 4
45 + size_of::<u16>() * 4
46 + 4 * 2
47 + self.payload.len()
48 + self.options.map(|i| i.len()).unwrap_or(0)
49 + self.padding.map(|i| i.len()).unwrap_or(0)
50 );
51 let mut writer = Cursor::new(inner);
52 writer.write_u8(self.version_and_length).unwrap();
53 writer.write_u8(self.tos).unwrap();
54 writer.write_u16::<BE>(self.raw_length).unwrap();
55 writer.write_u16::<BE>(self.id).unwrap();
56 writer.write_u16::<BE>(self.flags).unwrap();
57 writer.write_u8(self.ttl).unwrap();
58 writer.write_u8(self.protocol.value()).unwrap();
59 writer.write_u16::<BE>(self.checksum).unwrap();
60 if let IpAddr::V4(v) = self.src_ip {
61 writer.write(&v.octets()).unwrap();
62 }
63 if let IpAddr::V4(v) = self.dst_ip {
64 writer.write(&v.octets()).unwrap();
65 }
66 writer.write(self.payload).unwrap();
67 if let Some(i) = self.options {
68 writer.write(i).unwrap();
69 }
70 if let Some(i) = self.padding {
71 writer.write(i).unwrap();
72 }
73 writer.into_inner()
74 }
75
76 fn parse_ipv4<'b>(
77 input: &'b [u8],
78 input_length: usize,
79 version_and_length: u8,
80 ) -> IResult<&'b [u8], IPv4<'b>> {
81 let header_words = version_and_length & 0x0F;
82 let header_length = header_words * 4;
83 let additional_length = if header_words > 5 {
84 (header_words - 5) * 4
85 } else {
86 0
87 };
88
89 trace!(
90 "Input Length={} Header Length={} Additional Length={}",
91 input_length,
92 header_length,
93 additional_length
94 );
95
96 let (rem, (tos, (raw_length, length))) = do_parse!(
97 input,
98 tos: be_u8
99 >> lengths: map!(be_u16, |s| {
100 let l = s - (header_length as u16);
101 trace!("Payload Length={}", l);
102 (s, l)
103 })
104 >> ((tos, lengths))
105 )?;
106
107 let expected_length = header_length as usize + additional_length as usize + length as usize;
108 trace!(
109 "Input had length {}B, expected {}B",
110 input_length,
111 expected_length
112 );
113
114 do_parse!(
115 rem,
116 id: be_u16
117 >> flags: be_u16
118 >> ttl: be_u8
119 >> protocol: map_opt!(be_u8, InternetProtocolId::new)
120 >> checksum: be_u16
121 >> src_ip: ipv4_address
122 >> dst_ip: ipv4_address
123 >> payload: take!(length)
124 >> options: cond!(additional_length > 0, take!(additional_length))
125 >> padding:
126 cond!(
127 input_length > expected_length,
128 take!(input_length - expected_length)
129 )
130 >> (IPv4 {
131 version_and_length,
132 tos,
133 raw_length,
134 id,
135 flags,
136 ttl,
137 protocol,
138 checksum,
139 src_ip,
140 dst_ip,
141 payload,
142 options,
143 padding,
144 })
145 )
146 }
147
148 pub fn parse<'b>(input: &'b [u8]) -> Result<(&'b [u8], IPv4<'b>), Error> {
149 let input_len = input.len();
150
151 be_u8(input).map_err(Error::from).and_then(|r| {
152 let (rem, version_and_length) = r;
153 let version = version_and_length >> 4;
154 if version == 4 {
155 IPv4::parse_ipv4(rem, input_len, version_and_length).map_err(Error::from)
156 } else {
157 Err(Error::Custom { msg: format!("Expected version 4, was {}", version) } )
158 }
159 })
160 }
161}
162
#[cfg(test)]
pub mod tests {
    use super::*;

    /// A 72-byte IPv4 datagram: a 20-byte header without options
    /// (protocol TCP, 1.2.3.4 -> 10.11.12.13) followed by a 52-byte payload.
    pub const RAW_DATA: &[u8] = &[
        // version 4 / IHL 5, type of service, total length 0x0048 = 72
        0x45u8, 0x00u8, 0x00u8, 0x48u8,
        // identification, flags / fragment offset
        0x00u8, 0x00u8, 0x00u8, 0x00u8,
        // TTL 100, protocol 6 (TCP), header checksum
        0x64u8, 0x06u8, 0x00u8, 0x00u8,
        // source address 1.2.3.4
        0x01u8, 0x02u8, 0x03u8, 0x04u8,
        // destination address 10.11.12.13
        0x0Au8, 0x0Bu8, 0x0Cu8, 0x0Du8,
        // payload (52 bytes)
        0xC6u8, 0xB7u8, 0x00u8, 0x50u8, 0x00u8, 0x00u8, 0x00u8, 0x01u8,
        0x00u8, 0x00u8, 0x00u8, 0x02u8, 0x50u8, 0x00u8, 0x00u8, 0x00u8,
        0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x01u8, 0x02u8, 0x03u8, 0x04u8,
        0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8,
        0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8,
        0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8,
        0xFCu8, 0xFDu8, 0xFEu8, 0xFFu8,
    ];

    /// Parses `RAW_DATA`, checks the decoded addresses and protocol, and verifies that
    /// `as_bytes` reproduces the original bytes exactly.
    #[test]
    fn parse_ipv4() {
        let _ = env_logger::try_init();

        let (rem, l3) = IPv4::parse(RAW_DATA).expect("Unable to parse");
        assert!(rem.is_empty());

        let want_src = "1.2.3.4"
            .parse::<std::net::IpAddr>()
            .expect("Could not parse ip address");
        let want_dst = "10.11.12.13"
            .parse::<std::net::IpAddr>()
            .expect("Could not parse ip address");
        assert_eq!(l3.src_ip, want_src);
        assert_eq!(l3.dst_ip, want_dst);

        let is_tcp = match l3.protocol {
            InternetProtocolId::Tcp => true,
            _ => false,
        };
        assert!(is_tcp);

        let round_trip = l3.as_bytes();
        assert_eq!(round_trip.as_slice(), RAW_DATA);
    }
}