1use std::convert::TryInto as _;
21
/// Errors reported while decoding wire-format data.
#[derive(Debug, PartialEq, Eq)]
pub enum ParseError {
    /// The input ran out before a complete value could be read.
    UnexpectedEndOfInput,
    /// An argument was rejected; the string carries a free-form description.
    InvalidArgument(String),
}

// `ParseError` never wraps an underlying cause, so the trait's default
// `source()` (which returns `None`) is exactly what is needed.
impl std::error::Error for ParseError {}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidArgument(detail) => write!(f, "InvalidArgument: {}", detail),
            Self::UnexpectedEndOfInput => f.write_str("Unexpected End Of Input"),
        }
    }
}
42
/// A read cursor over a borrowed byte slice, used to decode wire-format data.
///
/// The fixed-width readers (`get_u8`, `get_bytes`, `get_buffer`, `get_vec`,
/// `get_be16`, `get_be32`, `get_ipv4`) either consume exactly the bytes they
/// decode and return `Some`, or return `None` and leave the cursor where it
/// was. Slices handed out by the cursor borrow the underlying data for `'l`,
/// not the cursor itself, so they stay usable while parsing continues.
pub struct Buffer<'l> {
    /// The complete input; never modified.
    buffer: &'l [u8],
    /// Index of the next unread byte. Every method keeps `offset <= buffer.len()`.
    offset: usize,
}

impl<'l> Buffer<'l> {
    /// Creates a cursor positioned at the first byte of `buffer`.
    pub const fn new(buffer: &'l [u8]) -> Buffer<'l> {
        Buffer { buffer, offset: 0 }
    }

    /// Number of bytes that have not been consumed yet.
    pub const fn remaining(&self) -> usize {
        // Cannot underflow: `offset <= buffer.len()` is maintained everywhere.
        self.buffer.len() - self.offset
    }

    /// Total length of the underlying slice, regardless of how much has been read.
    pub const fn size(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` once every byte has been consumed.
    pub const fn empty(&self) -> bool {
        self.remaining() == 0
    }

    /// Repositions the cursor at absolute index `o`, which may be before or
    /// after the current position.
    ///
    /// `o == size()` is allowed and yields an exhausted cursor. Returns `None`,
    /// consuming the cursor, if `o` is beyond the end of the data.
    pub const fn set_offset(mut self, o: usize) -> Option<Self> {
        if o <= self.size() {
            self.offset = o;
            Some(self)
        } else {
            None
        }
    }

    /// Advances the cursor by `s` bytes.
    ///
    /// Returns `None`, consuming the cursor, if fewer than `s` bytes remain.
    pub const fn skip(self, s: usize) -> Option<Self> {
        // Compare against `remaining()` instead of computing `offset + s`
        // first: a huge `s` would otherwise overflow, panicking in debug
        // builds and wrapping to a smaller (backwards) offset in release.
        if s <= self.remaining() {
            let new_offset = self.offset + s;
            self.set_offset(new_offset)
        } else {
            None
        }
    }

    /// Consumes and returns the next byte, or `None` if the buffer is exhausted.
    pub fn get_u8(&mut self) -> Option<u8> {
        let byte = self.peek_u8()?;
        self.offset += 1;
        Some(byte)
    }

    /// Returns the next byte without consuming it, or `None` if the buffer is exhausted.
    pub fn peek_u8(&self) -> Option<u8> {
        self.buffer.get(self.offset).copied()
    }

    /// Consumes the next `b` bytes and returns them as a slice of the underlying data.
    ///
    /// Returns `None`, without moving the cursor, if fewer than `b` bytes remain.
    pub fn get_bytes(&mut self, b: usize) -> Option<&'l [u8]> {
        // Copy the `&'l [u8]` out so the returned slice is tied to `'l`
        // rather than to this `&mut self` borrow.
        let data: &'l [u8] = self.buffer;
        // `checked_add` + `get` rather than `offset + b` and indexing: a huge
        // `b` would otherwise overflow the addition (debug panic) or wrap it
        // and then panic on an inverted slice range (release).
        let end = self.offset.checked_add(b)?;
        let bytes = data.get(self.offset..end)?;
        self.offset = end;
        Some(bytes)
    }

    /// Consumes the next `b` bytes and returns them as a new cursor positioned
    /// at its own start. Returns `None`, cursor unchanged, if too few bytes remain.
    pub fn get_buffer(&mut self, b: usize) -> Option<Buffer<'l>> {
        self.get_bytes(b).map(Buffer::new)
    }

    /// Consumes the next `b` bytes and returns an owned copy of them.
    pub fn get_vec(&mut self, b: usize) -> Option<Vec<u8>> {
        self.get_bytes(b).map(|bytes| bytes.to_vec())
    }

    /// Consumes two bytes and decodes them as a big-endian `u16`.
    pub fn get_be16(&mut self) -> Option<u16> {
        // `get_bytes(2)` yields exactly two bytes, so the conversion cannot
        // fail; `.ok()?` just avoids an `unwrap` panic path.
        let bytes: [u8; 2] = self.get_bytes(2)?.try_into().ok()?;
        Some(u16::from_be_bytes(bytes))
    }

    /// Consumes four bytes and decodes them as a big-endian `u32`.
    pub fn get_be32(&mut self) -> Option<u32> {
        let bytes: [u8; 4] = self.get_bytes(4)?.try_into().ok()?;
        Some(u32::from_be_bytes(bytes))
    }

    /// Consumes four bytes and returns them as an IPv4 address, the first byte
    /// being the first octet.
    pub fn get_ipv4(&mut self) -> Option<std::net::Ipv4Addr> {
        let octets: [u8; 4] = self.get_bytes(4)?.try_into().ok()?;
        Some(std::net::Ipv4Addr::from(octets))
    }

    /// Reads a tag-length-value record: one tag byte, one length byte, then
    /// `length` value bytes.
    ///
    /// The value borrows the underlying data (`'l`), not the cursor, so the
    /// caller can keep reading while holding it. Returns `None` if the header
    /// or the value is truncated; when only the value is short, the two header
    /// bytes stay consumed.
    pub fn get_tlv(&mut self) -> Option<(u8, &'l [u8])> {
        let header = self.get_bytes(2)?;
        let (tag, len) = (header[0], header[1]);
        let value = self.get_bytes(usize::from(len))?;
        Some((tag, value))
    }

    /// Reads one length-prefixed label: a length byte followed by that many bytes.
    fn get_label(&mut self) -> Option<&'l [u8]> {
        let len = self.get_u8()?;
        self.get_bytes(usize::from(len))
    }

    /// Reads labels up to and including a zero-length terminator and returns
    /// them (terminator excluded) as strings.
    ///
    /// Labels are decoded with `String::from_utf8_lossy`, so invalid UTF-8
    /// becomes U+FFFD instead of an error. Every length byte is taken
    /// literally; values >= 0xC0 get no special treatment. Returns `None` if
    /// the data ends before the terminator (earlier labels stay consumed).
    fn get_domain(&mut self) -> Option<Vec<String>> {
        let mut labels = Vec::new();
        loop {
            let label = self.get_label()?;
            if label.is_empty() {
                return Some(labels);
            }
            labels.push(String::from_utf8_lossy(label).into_owned());
        }
    }

    /// Reads consecutive domains (see `get_domain`) until the buffer is exhausted.
    ///
    /// Returns `None` if any domain is truncated.
    pub fn get_domains(&mut self) -> Option<Vec<Vec<String>>> {
        let mut domains = Vec::new();
        while !self.empty() {
            domains.push(self.get_domain()?);
        }
        Some(domains)
    }
}
172
173impl<'l> std::fmt::Display for Buffer<'l> {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
175 write!(f, "[")?;
176 for byte in &self.buffer[self.offset..] {
177 write!(f, "{:x} ", byte)?;
178 }
179 write!(f, "]")
180 }
181}
182
/// Types that can encode themselves as wire-format bytes.
pub trait Serialise {
    /// Returns the encoded bytes, or an `std::io::Error` if encoding fails.
    fn to_wire(&self) -> std::io::Result<Vec<u8>>;
}
186
187pub trait Deserialise {
188 fn from_wire(buf: &mut Buffer<'_>) -> Result<Self, ParseError>
189 where
190 Self: Sized;
191}
192
#[test]
fn test_get_u8() {
    let bytes = [1, 2, 3];
    let mut cursor = Buffer::new(&bytes);
    for expected in 1..=3u8 {
        assert_eq!(cursor.get_u8(), Some(expected));
    }
    // All three bytes are consumed, so further reads yield nothing.
    assert_eq!(cursor.get_u8(), None);
}
202
#[test]
fn test_get_bytes() {
    let bytes = [1, 2, 3, 4];
    let mut cursor = Buffer::new(&bytes);
    let first = cursor.get_bytes(2);
    let second = cursor.get_bytes(2);
    assert_eq!(first, Some(&[1u8, 2][..]));
    assert_eq!(second, Some(&[3u8, 4][..]));
    // Nothing is left, so another two-byte request must fail.
    assert_eq!(cursor.get_bytes(2), None)
}
218
#[test]
fn test_get_vec() {
    let data = [1, 2, 3, 4];
    let mut buffer = Buffer::new(&data);
    assert_eq!(buffer.get_vec(2), Some(vec![1u8, 2]));
    assert_eq!(buffer.get_vec(2), Some(vec![3u8, 4]));
    // Exhaustion is checked through `get_vec` itself so its `None` path is covered.
    assert_eq!(buffer.get_vec(2), None);
    // A zero-length read still succeeds on an exhausted buffer.
    assert_eq!(buffer.get_vec(0), Some(vec![]));
}
227
#[test]
fn test_get_u16() {
    let bytes = [1, 2, 3, 4];
    let mut cursor = Buffer::new(&bytes);
    // Three reads from four bytes: two big-endian words, then exhaustion.
    let words: Vec<Option<u16>> = (0..3).map(|_| cursor.get_be16()).collect();
    assert_eq!(words, vec![Some(0x0102), Some(0x0304), None]);
}
236
#[test]
fn test_get_u32() {
    let bytes = [1, 2, 3, 4, 5, 6, 7, 8];
    let mut cursor = Buffer::new(&bytes);
    // Three reads from eight bytes: two big-endian words, then exhaustion.
    let words: Vec<Option<u32>> = (0..3).map(|_| cursor.get_be32()).collect();
    assert_eq!(words, vec![Some(0x01020304), Some(0x05060708), None]);
}
245
#[test]
fn test_get_ipv4() {
    use std::net::Ipv4Addr;
    let octets = [1, 2, 3, 4, 5, 6, 7, 8];
    let mut cursor = Buffer::new(&octets);
    let addrs: Vec<Option<Ipv4Addr>> = (0..3).map(|_| cursor.get_ipv4()).collect();
    assert_eq!(
        addrs,
        vec![
            Some(Ipv4Addr::new(1, 2, 3, 4)),
            Some(Ipv4Addr::new(5, 6, 7, 8)),
            None
        ]
    );
}
254
#[test]
fn test_size() {
    let bytes = [1, 2, 3, 4, 5, 6, 7, 8];
    let mut cursor = Buffer::new(&bytes);
    let before = cursor.size();
    cursor.get_u8();
    // `size()` reports the whole slice and ignores the read position.
    assert_eq!((before, cursor.size()), (8, 8));
}
263
#[test]
fn test_remaining() {
    let bytes = [1, 2, 3, 4, 5, 6, 7, 8];
    let mut cursor = Buffer::new(&bytes);
    let before = cursor.remaining();
    cursor.get_u8();
    // One byte read, so exactly one fewer remains.
    assert_eq!((before, cursor.remaining()), (8, 7));
}
272
#[test]
fn test_domains() {
    fn owned(labels: &[&str]) -> Vec<String> {
        labels.iter().map(|label| label.to_string()).collect()
    }

    // Two length-prefixed names, each terminated by a zero-length label.
    let wire: &[u8] = b"\x03www\x07example\x03com\x00\x03www\x07example\x03org\x00";
    let mut cursor = Buffer::new(wire);
    let expected = vec![
        owned(&["www", "example", "com"]),
        owned(&["www", "example", "org"]),
    ];
    assert_eq!(cursor.get_domains(), Some(expected));
}