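//! `desert_core/binary_input.rs`
//!
//! Binary decoding primitives: the `BinaryInput` trait together with the
//! slice-backed `SliceInput` and the buffer-owning `OwnedInput` readers.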

1use std::io::Read;
2
3use flate2::read::DeflateDecoder;
4
5use crate::error::Result;
6use crate::Error;
7
8pub trait BinaryInput {
9    fn read_u8(&mut self) -> Result<u8>;
10    fn read_bytes(&mut self, count: usize) -> Result<&[u8]>;
11    fn skip(&mut self, count: usize) -> Result<()>;
12
13    fn read_i8(&mut self) -> Result<i8> {
14        Ok(self.read_u8()? as i8)
15    }
16
17    fn read_u16(&mut self) -> Result<u16> {
18        let bytes = self.read_bytes(2)?;
19        Ok(u16::from_be_bytes(bytes.try_into()?))
20    }
21
22    fn read_i16(&mut self) -> Result<i16> {
23        let bytes = self.read_bytes(2)?;
24        Ok(i16::from_be_bytes(bytes.try_into()?))
25    }
26
27    fn read_u32(&mut self) -> Result<u32> {
28        let bytes = self.read_bytes(4)?;
29        Ok(u32::from_be_bytes(bytes.try_into()?))
30    }
31
32    fn read_i32(&mut self) -> Result<i32> {
33        let bytes = self.read_bytes(4)?;
34        Ok(i32::from_be_bytes(bytes.try_into()?))
35    }
36
37    fn read_u64(&mut self) -> Result<u64> {
38        let bytes = self.read_bytes(8)?;
39        Ok(u64::from_be_bytes(bytes.try_into()?))
40    }
41
42    fn read_i64(&mut self) -> Result<i64> {
43        let bytes = self.read_bytes(8)?;
44        Ok(i64::from_be_bytes(bytes.try_into()?))
45    }
46
47    fn read_u128(&mut self) -> Result<u128> {
48        let bytes = self.read_bytes(16)?;
49        Ok(u128::from_be_bytes(bytes.try_into()?))
50    }
51
52    fn read_i128(&mut self) -> Result<i128> {
53        let bytes = self.read_bytes(16)?;
54        Ok(i128::from_be_bytes(bytes.try_into()?))
55    }
56
57    fn read_f32(&mut self) -> Result<f32> {
58        let bytes = self.read_bytes(4)?;
59        Ok(f32::from_be_bytes(bytes.try_into()?))
60    }
61
62    fn read_f64(&mut self) -> Result<f64> {
63        let bytes = self.read_bytes(8)?;
64        Ok(f64::from_be_bytes(bytes.try_into()?))
65    }
66
67    fn read_var_u32(&mut self) -> Result<u32> {
68        let mut result: u32 = 0;
69        let mut shift = 0;
70
71        loop {
72            let b = self.read_u8()?;
73            result |= ((b & 0x7F) as u32) << shift;
74            if b & 0x80 == 0 {
75                break;
76            }
77            shift += 7;
78            if shift >= 32 {
79                return Err(Error::DeserializationFailure(
80                    "var_u32 too long".to_string(),
81                ));
82            }
83        }
84
85        Ok(result)
86    }
87
88    fn read_var_i32(&mut self) -> Result<i32> {
89        let r = self.read_var_u32()?;
90        Ok(((r >> 1) ^ (-((r & 1) as i32) as u32)) as i32)
91    }
92
93    fn read_compressed(&mut self) -> Result<Vec<u8>> {
94        let uncompressed_len = self.read_var_u32()? as usize;
95        let compressed_len = self.read_var_u32()? as usize;
96        let compressed = self.read_bytes(compressed_len)?;
97        let mut deflater = DeflateDecoder::new(compressed);
98        let mut result = Vec::with_capacity(uncompressed_len);
99        deflater
100            .read_to_end(&mut result)
101            .map_err(|err| Error::DecompressionFailure(format!("{err}")))?;
102        Ok(result)
103    }
104}
105
/// A [`BinaryInput`] that reads from a borrowed byte slice.
pub struct SliceInput<'a> {
    /// The bytes being read.
    pub data: &'a [u8],
    /// Index into `data` of the next byte to read.
    pub pos: usize,
}

impl<'a> SliceInput<'a> {
    /// An input over an empty slice, positioned at index 0.
    pub const EMPTY: Self = SliceInput { data: &[], pos: 0 };

    /// Creates an input over `data`, positioned at its first byte.
    pub fn new(data: &'a [u8]) -> Self {
        SliceInput { pos: 0, data }
    }
}
118
119impl BinaryInput for SliceInput<'_> {
120    fn read_u8(&mut self) -> Result<u8> {
121        if self.pos == self.data.len() {
122            Err(Error::InputEndedUnexpectedly)
123        } else {
124            let result = self.data[self.pos];
125            self.pos += 1;
126            Ok(result)
127        }
128    }
129
130    fn read_bytes(&mut self, count: usize) -> Result<&[u8]> {
131        if self.pos + count > self.data.len() {
132            Err(Error::InputEndedUnexpectedly)
133        } else {
134            let result = &self.data[self.pos..self.pos + count];
135            self.pos += count;
136            Ok(result)
137        }
138    }
139
140    fn skip(&mut self, count: usize) -> Result<()> {
141        if self.pos + count > self.data.len() {
142            Err(Error::InputEndedUnexpectedly)
143        } else {
144            self.pos += count;
145            Ok(())
146        }
147    }
148}
149
/// A [`BinaryInput`] that owns the buffer it reads from.
pub struct OwnedInput {
    /// The bytes being read.
    data: Vec<u8>,
    /// Index into `data` of the next byte to read.
    pos: usize,
}

impl OwnedInput {
    /// Takes ownership of `data` and starts reading at its first byte.
    pub fn new(data: Vec<u8>) -> Self {
        OwnedInput { pos: 0, data }
    }
}
160
161impl BinaryInput for OwnedInput {
162    fn read_u8(&mut self) -> Result<u8> {
163        if self.pos == self.data.len() {
164            Err(Error::InputEndedUnexpectedly)
165        } else {
166            let result = self.data[self.pos];
167            self.pos += 1;
168            Ok(result)
169        }
170    }
171
172    fn read_bytes(&mut self, count: usize) -> Result<&[u8]> {
173        if self.pos + count > self.data.len() {
174            Err(Error::InputEndedUnexpectedly)
175        } else {
176            let result = &self.data[self.pos..self.pos + count];
177            self.pos += count;
178            Ok(result)
179        }
180    }
181
182    fn skip(&mut self, count: usize) -> Result<()> {
183        if self.pos + count > self.data.len() {
184            Err(Error::InputEndedUnexpectedly)
185        } else {
186            self.pos += count;
187            Ok(())
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use crate::binary_input::OwnedInput;
    use crate::{BinaryInput, BinaryOutput};
    use bytes::BytesMut;
    use proptest::prelude::*;
    use test_r::test;

    /// Turns everything written into `buffer` into a reader over those bytes.
    fn into_reader(buffer: BytesMut) -> OwnedInput {
        OwnedInput::new(buffer.freeze().to_vec())
    }

    proptest! {
        // A written var_i32 must decode back to the same value.
        #[test]
        fn roundtrip_var_i32(value: i32) {
            let mut encoded = BytesMut::new();
            encoded.write_var_i32(value);

            let mut reader = into_reader(encoded);
            let decoded = reader.read_var_i32().unwrap();
            assert_eq!(value, decoded);
        }

        // A written var_u32 must decode back to the same value.
        #[test]
        fn roundtrip_var_u32(value: u32) {
            let mut encoded = BytesMut::new();
            encoded.write_var_u32(value);

            let mut reader = into_reader(encoded);
            let decoded = reader.read_var_u32().unwrap();
            assert_eq!(value, decoded);
        }

        // Compressing and then decompressing must reproduce the original bytes.
        #[test]
        fn roundtrip_compressed(bytes: Vec<u8>) {
            let mut encoded = BytesMut::new();
            encoded.write_compressed(&bytes, Default::default()).unwrap();

            let mut reader = into_reader(encoded);
            let decoded = reader.read_compressed().unwrap();
            assert_eq!(bytes, decoded);
        }
    }

    // Bytes written from a sub-slice must be read back unchanged.
    #[test]
    fn roundtrip_slice() -> Result<(), crate::Error> {
        let data: [u8; 11] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

        let mut encoded = BytesMut::new();
        encoded.write_bytes(&data[2..6]);

        let mut reader = into_reader(encoded);
        let decoded = reader.read_bytes(4)?;

        assert_eq!(decoded, &[2, 3, 4, 5]);
        Ok(())
    }
}