// erbium/pktparser/mod.rs

/*   Copyright 2024 Perry Lorier
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *  SPDX-License-Identifier: Apache-2.0
 *
 *  API to make parsing packets easier.
 */
19
20use std::convert::TryInto as _;
21
/// Errors reported while decoding a value from its wire representation.
#[derive(PartialEq, Eq, Debug)]
pub enum ParseError {
    /// The input ended before the value could be fully decoded.
    UnexpectedEndOfInput,
    /// The input was rejected; the payload is a free-form description of why.
    InvalidArgument(String),
}
27
28impl std::error::Error for ParseError {
29    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
30        None
31    }
32}
33
34impl std::fmt::Display for ParseError {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            ParseError::UnexpectedEndOfInput => write!(f, "Unexpected End Of Input"),
38            ParseError::InvalidArgument(string) => write!(f, "InvalidArgument: {}", string),
39        }
40    }
41}
42
/// A read cursor over a borrowed byte slice, used to decode packets.
///
/// Slices returned by the accessors borrow the underlying data (lifetime
/// `'l`), not the `Buffer` itself, so the buffer can keep being read while
/// earlier results are still in use.
#[derive(Debug, Clone)]
pub struct Buffer<'l> {
    /// The complete underlying data, including bytes already consumed.
    buffer: &'l [u8],
    /// Index of the next unread byte. Every method keeps `offset <= buffer.len()`.
    offset: usize,
}

impl<'l> Buffer<'l> {
    /// Creates a cursor positioned at the start of `buffer`.
    pub const fn new(buffer: &'l [u8]) -> Buffer<'l> {
        Buffer { buffer, offset: 0 }
    }

    /// Returns the number of bytes that have not been consumed yet.
    pub const fn remaining(&self) -> usize {
        // Cannot underflow: `offset <= buffer.len()` is maintained everywhere.
        self.buffer.len() - self.offset
    }

    /// Returns the total length of the underlying data, regardless of position.
    pub const fn size(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` when every byte has been consumed.
    pub const fn empty(&self) -> bool {
        self.remaining() == 0
    }

    /// Moves the cursor to the absolute position `o`.
    ///
    /// `o == size()` is accepted and leaves the buffer empty; anything larger
    /// returns `None`.
    pub const fn set_offset(mut self, o: usize) -> Option<Self> {
        if o <= self.size() {
            self.offset = o;
            Some(self)
        } else {
            None
        }
    }

    /// Advances the cursor by `s` bytes, or returns `None` if fewer than `s`
    /// bytes remain.
    pub const fn skip(self, s: usize) -> Option<Self> {
        // Compare against `remaining()` instead of computing `offset + s`
        // first: a huge `s` (e.g. a length taken from the packet) would
        // otherwise overflow, panicking in debug builds and wrapping to a
        // small, apparently valid offset in release builds.
        if s > self.remaining() {
            return None;
        }
        // Cannot overflow: s <= len - offset.
        let new_offset = self.offset + s;
        self.set_offset(new_offset)
    }

    /// Consumes and returns the next byte, or `None` if the buffer is empty.
    pub fn get_u8(&mut self) -> Option<u8> {
        let byte = self.peek_u8()?;
        self.offset += 1;
        Some(byte)
    }

    /// Returns the next byte without consuming it, or `None` if the buffer is
    /// empty.
    pub fn peek_u8(&self) -> Option<u8> {
        self.buffer.get(self.offset).copied()
    }

    /// Consumes the next `b` bytes and returns them as a slice of the
    /// underlying data.
    ///
    /// Returns `None`, consuming nothing, if fewer than `b` bytes remain.
    pub fn get_bytes(&mut self, b: usize) -> Option<&'l [u8]> {
        // Checked against `remaining()` so an oversized `b` cannot overflow
        // `offset + b`.
        if b > self.remaining() {
            return None;
        }
        let data: &'l [u8] = self.buffer;
        let start = self.offset;
        self.offset += b;
        Some(&data[start..self.offset])
    }

    /// Consumes the next `b` bytes and returns them as a new `Buffer`
    /// positioned at its own start.
    ///
    /// Returns `None`, consuming nothing, if fewer than `b` bytes remain.
    pub fn get_buffer(&mut self, b: usize) -> Option<Buffer<'l>> {
        self.get_bytes(b).map(Buffer::new)
    }

    /// Like [`Buffer::get_bytes`], but copies the bytes into an owned `Vec`.
    pub fn get_vec(&mut self, b: usize) -> Option<Vec<u8>> {
        self.get_bytes(b).map(|bytes| bytes.to_vec())
    }

    /// Consumes a big-endian `u16`, or returns `None` if fewer than 2 bytes remain.
    pub fn get_be16(&mut self) -> Option<u16> {
        let bytes = self.get_bytes(std::mem::size_of::<u16>())?;
        Some(u16::from_be_bytes(
            bytes.try_into().expect("get_bytes returned exactly 2 bytes"),
        ))
    }

    /// Consumes a big-endian `u32`, or returns `None` if fewer than 4 bytes remain.
    pub fn get_be32(&mut self) -> Option<u32> {
        let bytes = self.get_bytes(std::mem::size_of::<u32>())?;
        Some(u32::from_be_bytes(
            bytes.try_into().expect("get_bytes returned exactly 4 bytes"),
        ))
    }

    /// Consumes four bytes as an IPv4 address (first byte = first octet).
    pub fn get_ipv4(&mut self) -> Option<std::net::Ipv4Addr> {
        let bytes = self.get_bytes(std::mem::size_of::<[u8; 4]>())?;
        Some(std::net::Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]))
    }

    /// Consumes a type-length-value record: a one-byte type, a one-byte
    /// length, then that many bytes of value.
    ///
    /// Returns `None`, consuming nothing, if the record is truncated.
    pub fn get_tlv(&mut self) -> Option<(u8, &'l [u8])> {
        // Parse on a scratch cursor and only commit the new position once the
        // whole record has been read.
        let mut probe = Buffer {
            buffer: self.buffer,
            offset: self.offset,
        };
        let t = probe.get_u8()?;
        let l = probe.get_u8()?;
        let v = probe.get_bytes(usize::from(l))?;
        self.offset = probe.offset;
        Some((t, v))
    }

    /// Consumes a one-byte length followed by that many bytes.
    ///
    /// NOTE: if the length byte is present but the data is truncated, the
    /// length byte stays consumed.
    fn get_label(&mut self) -> Option<&'l [u8]> {
        let l = self.get_u8()?;
        self.get_bytes(usize::from(l))
    }

    /// Reads one domain name encoded as length-prefixed labels terminated by
    /// a zero-length label (the uncompressed DNS name encoding).
    ///
    /// Labels are decoded as UTF-8, replacing invalid sequences. Length bytes
    /// are taken literally, so a DNS compression pointer (top two bits set) is
    /// not recognised. Returns `None` if the input ends before the terminating
    /// empty label.
    fn get_domain(&mut self) -> Option<Vec<String>> {
        let mut labels = vec![];
        loop {
            let label = self.get_label()?;
            if label.is_empty() {
                return Some(labels);
            }
            labels.push(String::from_utf8_lossy(label).into_owned());
        }
    }

    /// Reads domain names back to back until the buffer is exhausted.
    ///
    /// Returns `None` if any name is truncated; bytes read up to that point
    /// stay consumed.
    pub fn get_domains(&mut self) -> Option<Vec<Vec<String>>> {
        let mut domains = vec![];
        while !self.empty() {
            domains.push(self.get_domain()?);
        }
        Some(domains)
    }
}
172
173impl<'l> std::fmt::Display for Buffer<'l> {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
175        write!(f, "[")?;
176        for byte in &self.buffer[self.offset..] {
177            write!(f, "{:x} ", byte)?;
178        }
179        write!(f, "]")
180    }
181}
182
/// Types that can be encoded into a byte vector for transmission.
pub trait Serialise {
    /// Encodes `self`, returning the bytes or the I/O error raised while
    /// producing them.
    fn to_wire(&self) -> std::io::Result<Vec<u8>>;
}
186
187pub trait Deserialise {
188    fn from_wire(buf: &mut Buffer<'_>) -> Result<Self, ParseError>
189    where
190        Self: Sized;
191}
192
#[test]
fn test_get_u8() {
    let bytes = [1, 2, 3];
    let mut reader = Buffer::new(&bytes);
    // Bytes come back in order, one per call...
    for expected in 1..=3 {
        assert_eq!(reader.get_u8(), Some(expected));
    }
    // ...and the reader reports exhaustion afterwards.
    assert_eq!(reader.get_u8(), None);
}
202
#[test]
fn test_get_bytes() {
    let data = [1u8, 2, 3, 4];
    let mut buffer = Buffer::new(&data);
    // Compare slices directly rather than converting to fixed-size arrays.
    assert_eq!(buffer.get_bytes(2), Some(&[1u8, 2][..]));
    assert_eq!(buffer.get_bytes(2), Some(&[3u8, 4][..]));
    assert_eq!(buffer.get_bytes(2), None);
}
218
#[test]
fn test_get_vec() {
    let data = [1, 2, 3, 4];
    let mut buffer = Buffer::new(&data);
    assert_eq!(buffer.get_vec(2), Some(vec![1u8, 2]));
    assert_eq!(buffer.get_vec(2), Some(vec![3u8, 4]));
    // Exercise get_vec's own failure path (previously this checked get_bytes).
    assert_eq!(buffer.get_vec(2), None);
    // A zero-length read at the end is still a successful, empty read.
    assert_eq!(buffer.get_vec(0), Some(vec![]));
}
227
#[test]
fn test_get_u16() {
    let mut buffer = Buffer::new(&[1, 2, 3, 4]);
    let first = buffer.get_be16();
    let second = buffer.get_be16();
    // Both reads are big-endian: the first byte is the most significant.
    assert_eq!((first, second), (Some(0x0102), Some(0x0304)));
    assert_eq!(buffer.get_be16(), None);
}
236
#[test]
fn test_get_u32() {
    let mut buffer = Buffer::new(&[1, 2, 3, 4, 5, 6, 7, 8]);
    let words: Vec<_> = (0..2).map(|_| buffer.get_be32()).collect();
    assert_eq!(words, vec![Some(0x0102_0304), Some(0x0506_0708)]);
    assert_eq!(buffer.get_be32(), None);
}
245
#[test]
fn test_get_ipv4() {
    use std::net::Ipv4Addr;
    let octets = [1, 2, 3, 4, 5, 6, 7, 8];
    let mut buffer = Buffer::new(&octets);
    for &(a, b, c, d) in &[(1, 2, 3, 4), (5, 6, 7, 8)] {
        assert_eq!(buffer.get_ipv4(), Some(Ipv4Addr::new(a, b, c, d)));
    }
    assert_eq!(buffer.get_ipv4(), None);
}
254
#[test]
fn test_size() {
    let data = [1, 2, 3, 4, 5, 6, 7, 8];
    let mut buffer = Buffer::new(&data);
    let before = buffer.size();
    let _ = buffer.get_u8();
    // size() reports the whole slice, so consuming a byte must not change it.
    assert_eq!((before, buffer.size()), (8, 8));
}
263
#[test]
fn test_remaining() {
    let data = [1, 2, 3, 4, 5, 6, 7, 8];
    let mut buffer = Buffer::new(&data);
    let before = buffer.remaining();
    let _ = buffer.get_u8();
    // Consuming one byte shrinks the unread count by exactly one.
    assert_eq!((before, buffer.remaining()), (8, 7));
}
272
#[test]
fn test_domains() {
    fn labels(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|p| p.to_string()).collect()
    }
    // Two length-prefixed, zero-terminated names written back to back.
    let data = b"\x03www\x07example\x03com\x00\x03www\x07example\x03org\x00";
    let mut buf = Buffer::new(data);
    let expected = vec![
        labels(&["www", "example", "com"]),
        labels(&["www", "example", "org"]),
    ];
    assert_eq!(buf.get_domains(), Some(expected));
}