1use base64::{prelude::*, DecodeError};
4use hex::FromHexError;
5use std::{
6 convert::TryInto,
7 fmt::Debug,
8 io::{Cursor, Error as IOError, Read, Write},
9 vec,
10};
11use thiserror::Error;
12
13#[derive(Debug, Error)]
15pub enum SerError {
16 #[error("Attempted to deserialize non-minmal VarInt. Someone is doing something fishy.")]
18 NonMinimalVarInt,
19
20 #[error(transparent)]
22 IoError(#[from] IOError),
23
24 #[error(transparent)]
26 FromHexError(#[from] FromHexError),
27
28 #[error(transparent)]
30 DecodeError(#[from] DecodeError),
31
32 #[error("Error in component (de)serialization: {0}")]
34 ComponentError(String),
35
36 #[error("Expected a sequence of exaclty {expected} items. Got only {got} items")]
38 InsufficientSeqItems {
39 expected: usize,
41 got: usize,
43 },
44}
45
/// Controls how many items [`ByteFormat::read_seq_from`] reads from a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadSeqMode {
    /// Read exactly this many items; reading fewer is an error.
    Exactly(usize),
    /// Read up to this many items, stopping early when no more can be read.
    AtMost(usize),
    /// Read items until the reader yields no further item.
    UntilEnd,
}
55
56pub type SerResult<T> = Result<T, SerError>;
58
/// Returns the total encoded length (marker byte plus body) of `number` as a
/// minimal CompactInt: 1 for values below `0xfd`, 3 for values up to `u16::MAX`,
/// 5 for values up to `u32::MAX`, and 9 otherwise.
pub fn prefix_byte_len(number: u64) -> u8 {
    if number < 0xfd {
        1
    } else if number <= 0xffff {
        3
    } else if number <= 0xffff_ffff {
        5
    } else {
        9
    }
}
68
/// Maps a CompactInt's total encoded length to the marker byte that starts it.
///
/// Returns `None` for any length that does not use a marker byte (including
/// the single-byte form, where the value itself is the first byte).
pub fn first_byte_from_len(number: u8) -> Option<u8> {
    const MARKERS: [(u8, u8); 3] = [(3, 0xfd), (5, 0xfe), (9, 0xff)];
    MARKERS
        .iter()
        .find(|(len, _)| *len == number)
        .map(|&(_, marker)| marker)
}
78
/// Given the first byte of a CompactInt, returns its total encoded length.
///
/// The markers `0xfd`, `0xfe`, and `0xff` announce 2-, 4-, and 8-byte bodies;
/// every other byte is a complete one-byte encoding.
pub fn prefix_len_from_first_byte(number: u8) -> u8 {
    match number {
        0xfd => 3,
        0xfe => 5,
        0xff => 9,
        _ => 1,
    }
}
88
89pub fn write_compact_int<W>(writer: &mut W, number: u64) -> SerResult<usize>
91where
92 W: Write,
93{
94 let prefix_len = prefix_byte_len(number);
95 let written: usize = match first_byte_from_len(prefix_len) {
96 None => writer.write(&[number as u8])?,
97 Some(prefix) => {
98 let mut written = writer.write(&[prefix])?;
99 let body = number.to_le_bytes();
100 written += writer.write(&body[..prefix_len as usize - 1])?;
101 written
102 }
103 };
104 Ok(written)
105}
106
107pub fn read_compact_int<R>(reader: &mut R) -> SerResult<u64>
109where
110 R: Read,
111{
112 let mut prefix = [0u8; 1];
113 reader.read_exact(&mut prefix)?; let prefix_len = prefix_len_from_first_byte(prefix[0]);
115
116 let number = if prefix_len > 1 {
118 let mut buf = [0u8; 8];
119 let mut body = reader.take(prefix_len as u64 - 1); let _ = body.read(&mut buf)?;
121 u64::from_le_bytes(buf)
122 } else {
123 prefix[0] as u64
124 };
125
126 let minimal_length = prefix_byte_len(number);
127 if minimal_length < prefix_len {
128 Err(SerError::NonMinimalVarInt)
129 } else {
130 Ok(number)
131 }
132}
133
134pub fn read_u32_le<R>(reader: &mut R) -> SerResult<u32>
136where
137 R: Read,
138{
139 let mut buf = [0u8; 4];
140 reader.read_exact(&mut buf)?;
141 Ok(u32::from_le_bytes(buf))
142}
143
144pub fn write_u32_le<W>(writer: &mut W, number: u32) -> SerResult<usize>
146where
147 W: Write,
148{
149 Ok(writer.write(&number.to_le_bytes())?)
150}
151
152pub fn read_u64_le<R>(reader: &mut R) -> SerResult<u64>
154where
155 R: Read,
156{
157 let mut buf = [0u8; 8];
158 reader.read_exact(&mut buf)?;
159 Ok(u64::from_le_bytes(buf))
160}
161
162pub fn write_u64_le<W>(writer: &mut W, number: u64) -> SerResult<usize>
164where
165 W: Write,
166{
167 Ok(writer.write(&number.to_le_bytes())?)
168}
169
170pub fn read_prefix_vec<R, E, I>(reader: &mut R) -> Result<Vec<I>, E>
172where
173 R: Read,
174 E: From<SerError> + From<IOError> + std::error::Error,
175 I: ByteFormat<Error = E>,
176{
177 let items = read_compact_int(reader)?;
178 I::read_seq_from(reader, ReadSeqMode::Exactly(items.try_into().unwrap()))
179}
180
181pub fn write_prefix_vec<W, E, I>(writer: &mut W, vector: &[I]) -> Result<usize, E>
183where
184 W: Write,
185 E: From<SerError> + From<IOError> + std::error::Error,
186 I: ByteFormat<Error = E>,
187{
188 let mut written = write_compact_int(writer, vector.len() as u64)?;
189 written += I::write_seq_to(writer, vector.iter())?;
190 Ok(written)
191}
192
193pub trait ByteFormat {
198 type Error: From<SerError> + From<IOError> + std::error::Error;
200
201 fn serialized_length(&self) -> usize;
203
204 fn read_from<R>(reader: &mut R) -> Result<Self, Self::Error>
218 where
219 R: Read,
220 Self: std::marker::Sized;
221
222 fn write_to<W>(&self, writer: &mut W) -> Result<usize, <Self as ByteFormat>::Error>
238 where
239 W: Write;
240
241 fn read_seq_from<R>(reader: &mut R, mode: ReadSeqMode) -> Result<Vec<Self>, Self::Error>
244 where
245 R: Read,
246 Self: std::marker::Sized,
247 {
248 let mut v = vec![];
249 match mode {
250 ReadSeqMode::Exactly(number) => {
251 for _ in 0..number {
252 v.push(Self::read_from(reader)?);
253 }
254 if v.len() != number {
255 return Err(SerError::InsufficientSeqItems {
256 got: v.len(),
257 expected: number,
258 }
259 .into());
260 }
261 }
262 ReadSeqMode::AtMost(limit) => {
263 for _ in 0..limit {
264 v.push(Self::read_from(reader)?);
265 }
266 }
267 ReadSeqMode::UntilEnd => {
268 while let Ok(obj) = Self::read_from(reader) {
269 v.push(obj);
270 }
271 }
272 }
273 Ok(v)
274 }
275
276 fn write_seq_to<'a, W, E, Iter, Item>(
306 writer: &mut W,
307 iter: Iter,
308 ) -> Result<usize, <Self as ByteFormat>::Error>
309 where
310 W: Write,
311 E: Into<Self::Error> + From<SerError> + From<IOError> + std::error::Error,
312 Item: 'a + ByteFormat<Error = E>,
313 Iter: IntoIterator<Item = &'a Item>,
314 {
315 let mut written = 0;
316 for item in iter {
317 written += item.write_to(writer).map_err(Into::into)?;
318 }
319 Ok(written)
320 }
321
322 fn deserialize_hex(s: &str) -> Result<Self, Self::Error>
324 where
325 Self: std::marker::Sized,
326 {
327 let v: Vec<u8> = hex::decode(s).map_err(SerError::from)?;
328 let mut cursor = Cursor::new(v);
329 Self::read_from(&mut cursor)
330 }
331
332 fn deserialize_base64(s: &str) -> Result<Self, Self::Error>
334 where
335 Self: std::marker::Sized,
336 {
337 let v: Vec<u8> = BASE64_STANDARD.decode(s).map_err(SerError::from)?;
338 let mut cursor = Cursor::new(v);
339 Self::read_from(&mut cursor)
340 }
341
342 fn serialize_hex(&self) -> String {
344 let mut v: Vec<u8> = vec![];
345 self.write_to(&mut v).expect("No error on heap write");
346 hex::encode(v)
347 }
348
349 fn serialize_base64(&self) -> String {
351 let mut v: Vec<u8> = vec![];
352 self.write_to(&mut v).expect("No error on heap write");
353 BASE64_STANDARD.encode(v)
354 }
355}
356
357impl ByteFormat for u8 {
358 type Error = SerError;
359
360 fn serialized_length(&self) -> usize {
361 1
362 }
363
364 fn read_seq_from<R>(reader: &mut R, mode: ReadSeqMode) -> SerResult<Vec<u8>>
365 where
366 R: Read,
367 Self: std::marker::Sized,
368 {
369 match mode {
370 ReadSeqMode::Exactly(number) => {
371 let mut v = vec![0u8; number];
372 reader.read_exact(v.as_mut_slice())?;
373 Ok(v)
374 }
375 ReadSeqMode::AtMost(limit) => {
376 let mut v = vec![0u8; limit];
377 let n = reader.read(v.as_mut_slice())?;
378 v.truncate(n);
379 Ok(v)
380 }
381 ReadSeqMode::UntilEnd => {
382 let mut buf = vec![];
383 reader.read_to_end(&mut buf)?;
384 Ok(buf)
385 }
386 }
387 }
388
389 fn read_from<R>(reader: &mut R) -> SerResult<Self>
390 where
391 R: Read,
392 Self: std::marker::Sized,
393 {
394 let mut buf = [0u8; 1];
395 reader.read_exact(&mut buf)?;
396 Ok(u8::from_le_bytes(buf))
397 }
398
399 fn write_to<W>(&self, writer: &mut W) -> SerResult<usize>
400 where
401 W: Write,
402 {
403 Ok(writer.write(&self.to_le_bytes())?)
404 }
405}
406
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn it_matches_byte_len_and_prefix() {
        let cases = [
            (1, 1, None),
            (0xff, 3, Some(0xfd)),
            (0xffff_ffff, 5, Some(0xfe)),
            (0xffff_ffff_ffff_ffff, 9, Some(0xff)),
        ];
        for case in cases.iter() {
            assert_eq!(prefix_byte_len(case.0), case.1);
            assert_eq!(first_byte_from_len(case.1), case.2);
        }
    }

    #[test]
    fn it_round_trips_compact_ints() {
        // (value, expected encoded length), covering every encoding boundary.
        let cases: [(u64, usize); 8] = [
            (0, 1),
            (0xfc, 1),
            (0xfd, 3),
            (0xffff, 3),
            (0x1_0000, 5),
            (0xffff_ffff, 5),
            (0x1_0000_0000, 9),
            (u64::MAX, 9),
        ];
        for &(number, len) in cases.iter() {
            let mut buf = vec![];
            assert_eq!(write_compact_int(&mut buf, number).unwrap(), len);
            assert_eq!(buf.len(), len);
            assert_eq!(read_compact_int(&mut buf.as_slice()).unwrap(), number);
        }
    }

    #[test]
    fn it_rejects_non_minimal_compact_ints() {
        // The value 1 encoded with a 3-byte (0xfd) prefix.
        let encoded = [0xfd, 0x01, 0x00];
        match read_compact_int(&mut &encoded[..]) {
            Err(SerError::NonMinimalVarInt) => {}
            other => panic!("expected NonMinimalVarInt, got {:?}", other),
        }
    }

    #[test]
    fn it_rejects_truncated_compact_ints() {
        // A 0xfe marker promises 4 body bytes, but only 1 follows.
        let truncated = [0xfe, 0x01];
        assert!(read_compact_int(&mut &truncated[..]).is_err());
    }

    #[test]
    fn it_round_trips_prefixed_vecs() {
        let input: Vec<u8> = vec![9, 8, 7];
        let mut buf = vec![];
        let written = write_prefix_vec::<_, SerError, u8>(&mut buf, &input).unwrap();
        assert_eq!(written, 4);
        assert_eq!(buf, vec![3, 9, 8, 7]);

        let output: Vec<u8> = read_prefix_vec::<_, SerError, u8>(&mut buf.as_slice()).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn it_implements_byteformat_for_u8() {
        // Inclusive range so 0xff is exercised too.
        for i in 0..=u8::MAX {
            let size = i.serialized_length();
            assert_eq!(size, 1);

            let mut v = vec![];
            i.write_to(&mut v).unwrap();
            let mut slice = v.as_slice();

            let decoded = u8::read_from(&mut slice).unwrap();
            assert_eq!(i, decoded);
        }
    }

    #[test]
    fn it_implements_seq_ops_for_u8() {
        let input = vec![0, 1, 2, 3, 4];
        let mut buf = vec![];
        u8::write_seq_to(&mut buf, input.iter()).unwrap();
        assert_eq!(buf.len(), input.len());
        assert_eq!(buf, input);

        let exact_len =
            u8::read_seq_from(&mut buf.clone().as_slice(), ReadSeqMode::Exactly(buf.len()))
                .unwrap();
        assert_eq!(exact_len.len(), buf.len());
        assert_eq!(input, exact_len);

        let exact_too_long = u8::read_seq_from(
            &mut buf.clone().as_slice(),
            ReadSeqMode::Exactly(buf.len() + 1),
        );
        assert!(exact_too_long.is_err());

        let exact_first =
            u8::read_seq_from(&mut buf.clone().as_slice(), ReadSeqMode::Exactly(1)).unwrap();
        assert_eq!(exact_first, vec![0]);

        let exact_none =
            u8::read_seq_from(&mut buf.clone().as_slice(), ReadSeqMode::Exactly(0)).unwrap();
        assert_eq!(exact_none, Vec::<u8>::new());

        let at_most_all =
            u8::read_seq_from(&mut buf.clone().as_slice(), ReadSeqMode::AtMost(buf.len())).unwrap();
        assert_eq!(at_most_all, buf.clone());
        let at_most_more = u8::read_seq_from(
            &mut buf.clone().as_slice(),
            ReadSeqMode::AtMost(buf.len() + 10),
        )
        .unwrap();
        assert_eq!(at_most_more, buf.clone());

        let at_most_less = u8::read_seq_from(
            &mut buf.clone().as_slice(),
            ReadSeqMode::AtMost(buf.len() - 1),
        )
        .unwrap();
        let mut resized = buf.clone();
        resized.resize(buf.len() - 1, 0);
        assert_eq!(at_most_less, resized);

        let until_end =
            u8::read_seq_from(&mut buf.clone().as_slice(), ReadSeqMode::UntilEnd).unwrap();
        assert_eq!(until_end, buf.clone());
    }
}