net_parser_rs/layer3/
ipv4.rs

1use crate::Error;
2use crate::layer3::InternetProtocolId;
3
4use arrayref::array_ref;
5use byteorder::{BigEndian as BE, WriteBytesExt};
6use log::*;
7use nom::*;
8use std::mem::size_of;
9use std::net::IpAddr;
10use std::io::{Cursor, Write};
11
/// Length in bytes of an IPv4 address (the source and destination header fields).
const ADDRESS_LENGTH: usize = 4;
13
/// A parsed IPv4 packet.
///
/// Fixed header fields are stored as read (big-endian decoded). The payload,
/// options and padding are zero-copy slices borrowed from the parsed input.
#[derive(Clone, Copy, Debug)]
pub struct IPv4<'a> {
    /// First header byte: version in the high nibble, header length (IHL,
    /// counted in 32-bit words) in the low nibble.
    pub version_and_length: u8,
    /// Type-of-service byte.
    pub tos: u8,
    /// Total-length field exactly as read from the header (header plus
    /// payload, in bytes).
    pub raw_length: u16,
    /// Identification field.
    pub id: u16,
    /// The raw 16-bit word that follows `id`. NOTE: in IPv4 this word packs
    /// the 3 flag bits together with the 13-bit fragment offset; it is kept
    /// unsplit here.
    pub flags: u16,
    /// Time-to-live.
    pub ttl: u8,
    /// Protocol number, decoded via `InternetProtocolId::new`.
    pub protocol: InternetProtocolId,
    /// Header checksum as read; it is not verified or recomputed.
    pub checksum: u16,
    /// Source address (always `IpAddr::V4` when produced by the parser).
    pub src_ip: IpAddr,
    /// Destination address (always `IpAddr::V4` when produced by the parser).
    pub dst_ip: IpAddr,
    /// Payload bytes: `raw_length` minus the header length.
    pub payload: &'a [u8],
    /// Header option bytes; `Some` only when the IHL exceeds 5 words.
    pub options: Option<&'a [u8]>,
    /// Extra trailing bytes of the parsed input beyond the packet itself
    /// (e.g. link-layer padding), if any.
    pub padding: Option<&'a [u8]>,
}
30
31fn to_ip_address(i: &[u8]) -> IpAddr {
32    let ipv4 = std::net::Ipv4Addr::from(array_ref![i, 0, ADDRESS_LENGTH].clone());
33    IpAddr::V4(ipv4)
34}
35
36named!(
37    ipv4_address<&[u8], std::net::IpAddr>,
38    map!(take!(ADDRESS_LENGTH), to_ip_address)
39);
40
41impl<'a> IPv4<'a> {
42    pub fn as_bytes(&self) -> Vec<u8> {
43        let inner = Vec::with_capacity(
44            size_of::<u8>() * 4
45            + size_of::<u16>() * 4
46            + 4 * 2
47            + self.payload.len()
48            + self.options.map(|i| i.len()).unwrap_or(0)
49            + self.padding.map(|i| i.len()).unwrap_or(0)
50        );
51        let mut writer = Cursor::new(inner);
52        writer.write_u8(self.version_and_length).unwrap();
53        writer.write_u8(self.tos).unwrap();
54        writer.write_u16::<BE>(self.raw_length).unwrap();
55        writer.write_u16::<BE>(self.id).unwrap();
56        writer.write_u16::<BE>(self.flags).unwrap();
57        writer.write_u8(self.ttl).unwrap();
58        writer.write_u8(self.protocol.value()).unwrap();
59        writer.write_u16::<BE>(self.checksum).unwrap();
60        if let IpAddr::V4(v) = self.src_ip {
61            writer.write(&v.octets()).unwrap();
62        }
63        if let IpAddr::V4(v) = self.dst_ip {
64            writer.write(&v.octets()).unwrap();
65        }
66        writer.write(self.payload).unwrap();
67        if let Some(i) = self.options {
68            writer.write(i).unwrap();
69        }
70        if let Some(i) = self.padding {
71            writer.write(i).unwrap();
72        }
73        writer.into_inner()
74    }
75
76    fn parse_ipv4<'b>(
77        input: &'b [u8],
78        input_length: usize,
79        version_and_length: u8,
80    ) -> IResult<&'b [u8], IPv4<'b>> {
81        let header_words = version_and_length & 0x0F;
82        let header_length = header_words * 4;
83        let additional_length = if header_words > 5 {
84            (header_words - 5) * 4
85        } else {
86            0
87        };
88
89        trace!(
90            "Input Length={}   Header Length={}   Additional Length={}",
91            input_length,
92            header_length,
93            additional_length
94        );
95
96        let (rem, (tos, (raw_length, length))) = do_parse!(
97            input,
98            tos: be_u8
99                >> lengths: map!(be_u16, |s| {
100                    let l = s - (header_length as u16);
101                    trace!("Payload Length={}", l);
102                    (s, l)
103                })
104                >> ((tos, lengths))
105        )?;
106
107        let expected_length = header_length as usize + additional_length as usize + length as usize;
108        trace!(
109            "Input had length {}B, expected {}B",
110            input_length,
111            expected_length
112        );
113
114        do_parse!(
115            rem,
116            id: be_u16
117                >> flags: be_u16
118                >> ttl: be_u8
119                >> protocol: map_opt!(be_u8, InternetProtocolId::new)
120                >> checksum: be_u16
121                >> src_ip: ipv4_address
122                >> dst_ip: ipv4_address
123                >> payload: take!(length)
124                >> options: cond!(additional_length > 0, take!(additional_length))
125                >> padding:
126                    cond!(
127                        input_length > expected_length,
128                        take!(input_length - expected_length)
129                    )
130                >> (IPv4 {
131                    version_and_length,
132                    tos,
133                    raw_length,
134                    id,
135                    flags,
136                    ttl,
137                    protocol,
138                    checksum,
139                    src_ip,
140                    dst_ip,
141                    payload,
142                    options,
143                    padding,
144                })
145        )
146    }
147
148    pub fn parse<'b>(input: &'b [u8]) -> Result<(&'b [u8], IPv4<'b>), Error> {
149        let input_len = input.len();
150
151        be_u8(input).map_err(Error::from).and_then(|r| {
152            let (rem, version_and_length) = r;
153            let version = version_and_length >> 4;
154            if version == 4 {
155                IPv4::parse_ipv4(rem, input_len, version_and_length).map_err(Error::from)
156            } else {
157                Err(Error::Custom { msg: format!("Expected version 4, was {}", version) } )
158            }
159        })
160    }
161}
162
#[cfg(test)]
pub mod tests {
    use super::*;

    /// A 72-byte IPv4/TCP packet: 20-byte IPv4 header (no options), 20-byte
    /// TCP header, 32-byte payload.
    pub const RAW_DATA: &[u8] = &[
        0x45u8, //version and header length
        0x00u8, //tos
        0x00u8, 0x48u8, //length, 20 bytes for header, 52 bytes for the tcp segment
        0x00u8, 0x00u8, //id
        0x00u8, 0x00u8, //flags and fragment offset
        0x64u8, //ttl
        0x06u8, //protocol, tcp
        0x00u8, 0x00u8, //checksum
        0x01u8, 0x02u8, 0x03u8, 0x04u8, //src ip 1.2.3.4
        0x0Au8, 0x0Bu8, 0x0Cu8, 0x0Du8, //dst ip 10.11.12.13
        //tcp
        0xC6u8, 0xB7u8, //src port, 50871
        0x00u8, 0x50u8, //dst port, 80
        0x00u8, 0x00u8, 0x00u8, 0x01u8, //sequence number, 1
        0x00u8, 0x00u8, 0x00u8, 0x02u8, //acknowledgement number, 2
        0x50u8, 0x00u8, //data offset 5 words, flags 0
        0x00u8, 0x00u8, //window
        0x00u8, 0x00u8, //check
        0x00u8, 0x00u8, //urgent
        //no options
        //payload
        0x01u8, 0x02u8, 0x03u8, 0x04u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8,
        0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8,
        0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0xfcu8, 0xfdu8, 0xfeu8,
        0xffu8, //payload, 8 words
    ];

    #[test]
    fn parse_ipv4() {
        let _ = env_logger::try_init();

        let (rem, l3) = IPv4::parse(RAW_DATA).expect("Unable to parse");

        assert!(rem.is_empty());
        assert_eq!(
            l3.src_ip,
            "1.2.3.4"
                .parse::<std::net::IpAddr>()
                .expect("Could not parse ip address")
        );
        assert_eq!(
            l3.dst_ip,
            "10.11.12.13"
                .parse::<std::net::IpAddr>()
                .expect("Could not parse ip address")
        );

        match l3.protocol {
            InternetProtocolId::Tcp => {}
            _ => panic!("Expected protocol to be TCP"),
        }

        assert_eq!(usize::from(l3.raw_length), RAW_DATA.len());
        assert_eq!(l3.payload, &RAW_DATA[20..]);
        assert!(l3.options.is_none());
        assert!(l3.padding.is_none());

        assert_eq!(l3.as_bytes().as_slice(), RAW_DATA);
    }

    #[test]
    fn parse_ipv4_with_trailing_padding() {
        let _ = env_logger::try_init();

        // Bytes past the declared total length must be captured as padding.
        let mut data = RAW_DATA.to_vec();
        data.extend_from_slice(&[0u8; 6]);

        let (rem, l3) = IPv4::parse(&data).expect("Unable to parse");

        assert!(rem.is_empty());
        assert_eq!(l3.payload, &RAW_DATA[20..]);
        assert_eq!(l3.padding, Some(&data[RAW_DATA.len()..]));
        assert_eq!(l3.as_bytes(), data);
    }

    #[test]
    fn parse_rejects_non_v4_version() {
        let mut data = RAW_DATA.to_vec();
        data[0] = 0x65; // version 6, IHL 5

        assert!(IPv4::parse(&data).is_err());
    }
}