1use core::fmt;
11
/// Destination for encoded bytes.
///
/// The `write_*` functions in this module emit their encoding as a series of `put`
/// calls; an implementor decides what to do with each chunk (buffer it, hash it, ...).
pub trait Sink {
    /// Accepts `chunk` as the next piece of encoded output.
    fn put(&mut self, chunk: &[u8]);
}
20
21impl Sink for blake3::Hasher {
22 fn put(&mut self, bytes: &[u8]) {
23 self.update(bytes);
24 }
25}
26
27impl Sink for Vec<u8> {
28 fn put(&mut self, bytes: &[u8]) {
29 self.extend_from_slice(bytes);
30 }
31}
32
33pub fn write_u8<S: Sink>(out: &mut S, n: u8) {
34 out.put(&[n]);
35}
36
37pub fn write_u16<S: Sink>(out: &mut S, n: u16) {
38 out.put(&n.to_le_bytes());
39}
40
41pub fn write_u32<S: Sink>(out: &mut S, n: u32) {
42 out.put(&n.to_le_bytes());
43}
44
45pub fn write_u64<S: Sink>(out: &mut S, n: u64) {
46 out.put(&n.to_le_bytes());
47}
48
49pub fn write_u128<S: Sink>(out: &mut S, n: u128) {
50 out.put(&n.to_le_bytes());
51}
52
53pub fn write_i8<S: Sink>(out: &mut S, n: i8) {
54 out.put(&n.to_le_bytes());
55}
56
57pub fn write_i16<S: Sink>(out: &mut S, n: i16) {
58 out.put(&n.to_le_bytes());
59}
60
61pub fn write_i32<S: Sink>(out: &mut S, n: i32) {
62 out.put(&n.to_le_bytes());
63}
64
65pub fn write_i64<S: Sink>(out: &mut S, n: i64) {
66 out.put(&n.to_le_bytes());
67}
68
69pub fn write_i128<S: Sink>(out: &mut S, n: i128) {
70 out.put(&n.to_le_bytes());
71}
72
73pub fn write_f32<S: Sink>(out: &mut S, n: f32) {
74 out.put(&n.to_le_bytes());
75}
76
77pub fn write_f64<S: Sink>(out: &mut S, n: f64) {
78 out.put(&n.to_le_bytes());
79}
80
81pub fn write_bool<S: Sink>(out: &mut S, b: bool) {
82 write_u8(out, u8::from(b));
83}
84
85pub fn write_str<S: Sink>(out: &mut S, s: &str) {
87 write_u32(out, s.len() as u32);
88 out.put(s.as_bytes());
89}
90
91pub fn write_bytes<S: Sink>(out: &mut S, b: &[u8]) {
93 write_u32(out, b.len() as u32);
94 out.put(b);
95}
96
/// Appends zero bytes to `out` until its length is a multiple of `n`.
///
/// An alignment of `0` imposes no constraint and leaves `out` unchanged. Previously,
/// `n == 0` with a non-empty buffer never terminated, because no non-zero length is a
/// multiple of zero. An alignment of `1` is likewise a no-op.
pub fn pad_to(out: &mut Vec<u8>, n: usize) {
    if n == 0 {
        return;
    }
    let misalignment = out.len() % n;
    if misalignment != 0 {
        // Grow once to the next multiple of `n` instead of pushing one byte at a time.
        out.resize(out.len() + (n - misalignment), 0);
    }
}
106
/// Everything that can go wrong while decoding a value from a byte buffer.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum DecodeError {
    /// A read needed `needed` bytes but only `remaining` were left in the input.
    UnexpectedEof { needed: usize, remaining: usize },
    /// A tag byte did not correspond to any known tag.
    UnknownTag(u8),
    /// A bool was encoded as a byte other than `0` or `1`; holds the offending byte.
    InvalidBool(u8),
    /// A string's bytes were not valid UTF-8.
    InvalidUtf8,
    /// A `u32` that should have been a `char` is not a Unicode scalar value.
    InvalidChar(u32),
    /// A length prefix of `count` elements is implausible given the `remaining` input
    /// bytes (or exceeds a fixed cap when elements occupy no bytes).
    LengthTooLarge { count: u64, remaining: usize },
    /// Values were nested deeper than the decoder's maximum depth.
    DepthExceeded,
    /// A map contained the same key more than once.
    DuplicateKey,
    /// A set contained the same element more than once.
    DuplicateElement,
    /// A tag byte `got` appeared where a value described by `expected` was required.
    UnexpectedTag { expected: &'static str, got: u8 },
    /// An enum variant name was not recognized; holds the name.
    UnknownVariant(String),
    /// The input carries enum variant `0`, which the reader's schema does not have.
    WriterOnlyVariant(u32),
    /// An enum variant index was out of range.
    BadVariantIndex(u32),
    /// A value was structurally invalid; the string says what was wrong.
    Malformed(&'static str),
    /// Decoding finished with this many unconsumed bytes left over.
    TrailingBytes(usize),
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DecodeError::UnexpectedEof { needed, remaining } => {
                write!(
                    f,
                    "unexpected end of input: need {needed}, have {remaining}"
                )
            }
            DecodeError::UnknownTag(t) => write!(f, "unknown tag {t:#04x}"),
            DecodeError::InvalidBool(b) => write!(f, "invalid bool byte {b:#04x}"),
            DecodeError::InvalidUtf8 => write!(f, "invalid UTF-8 in string"),
            DecodeError::InvalidChar(c) => write!(f, "invalid Unicode scalar {c:#x}"),
            DecodeError::LengthTooLarge { count, remaining } => {
                // `count` is an element count, not a byte count: it can be numerically
                // smaller than `remaining` and still be rejected (e.g. 10 eight-byte
                // elements with 40 bytes left), so don't claim it "exceeds" remaining.
                write!(
                    f,
                    "length {count} is too large for the {remaining} bytes remaining"
                )
            }
            DecodeError::DepthExceeded => write!(f, "maximum nesting depth exceeded"),
            DecodeError::DuplicateKey => write!(f, "duplicate map key"),
            DecodeError::DuplicateElement => write!(f, "duplicate set element"),
            DecodeError::UnexpectedTag { expected, got } => {
                write!(f, "expected {expected}, got tag {got:#04x}")
            }
            DecodeError::UnknownVariant(name) => write!(f, "unknown variant {name:?}"),
            DecodeError::WriterOnlyVariant(i) => {
                write!(
                    f,
                    "received enum variant {i} the reader schema does not have"
                )
            }
            DecodeError::BadVariantIndex(i) => write!(f, "enum variant index {i} out of range"),
            DecodeError::Malformed(what) => write!(f, "malformed value: {what}"),
            DecodeError::TrailingBytes(n) => write!(f, "{n} trailing bytes after value"),
        }
    }
}

/// `DecodeError` carries no underlying cause, so the default `source()` (none) applies.
impl std::error::Error for DecodeError {}
191
192pub struct Reader<'a> {
198 buf: &'a [u8],
199 pos: usize,
200}
201
202impl<'a> Reader<'a> {
203 #[must_use]
204 pub fn new(buf: &'a [u8]) -> Self {
205 Reader { buf, pos: 0 }
206 }
207
208 #[must_use]
210 pub fn remaining(&self) -> usize {
211 self.buf.len() - self.pos
212 }
213
214 #[must_use]
216 pub fn position(&self) -> usize {
217 self.pos
218 }
219
220 fn take(&mut self, n: usize) -> Result<&'a [u8], DecodeError> {
221 if self.remaining() < n {
222 return Err(DecodeError::UnexpectedEof {
223 needed: n,
224 remaining: self.remaining(),
225 });
226 }
227 let slice = &self.buf[self.pos..self.pos + n];
228 self.pos += n;
229 Ok(slice)
230 }
231
232 pub fn read_slice(&mut self, n: usize) -> Result<&'a [u8], DecodeError> {
234 self.take(n)
235 }
236
237 pub fn read_u8(&mut self) -> Result<u8, DecodeError> {
238 Ok(self.take(1)?[0])
239 }
240
241 pub fn read_u16(&mut self) -> Result<u16, DecodeError> {
242 Ok(u16::from_le_bytes(self.take(2)?.try_into().unwrap()))
243 }
244
245 pub fn read_u32(&mut self) -> Result<u32, DecodeError> {
246 Ok(u32::from_le_bytes(self.take(4)?.try_into().unwrap()))
247 }
248
249 pub fn read_u64(&mut self) -> Result<u64, DecodeError> {
250 Ok(u64::from_le_bytes(self.take(8)?.try_into().unwrap()))
251 }
252
253 pub fn read_u128(&mut self) -> Result<u128, DecodeError> {
254 Ok(u128::from_le_bytes(self.take(16)?.try_into().unwrap()))
255 }
256
257 pub fn read_i8(&mut self) -> Result<i8, DecodeError> {
258 Ok(i8::from_le_bytes(self.take(1)?.try_into().unwrap()))
259 }
260
261 pub fn read_i16(&mut self) -> Result<i16, DecodeError> {
262 Ok(i16::from_le_bytes(self.take(2)?.try_into().unwrap()))
263 }
264
265 pub fn read_i32(&mut self) -> Result<i32, DecodeError> {
266 Ok(i32::from_le_bytes(self.take(4)?.try_into().unwrap()))
267 }
268
269 pub fn read_i64(&mut self) -> Result<i64, DecodeError> {
270 Ok(i64::from_le_bytes(self.take(8)?.try_into().unwrap()))
271 }
272
273 pub fn read_i128(&mut self) -> Result<i128, DecodeError> {
274 Ok(i128::from_le_bytes(self.take(16)?.try_into().unwrap()))
275 }
276
277 pub fn read_f32(&mut self) -> Result<f32, DecodeError> {
278 Ok(f32::from_le_bytes(self.take(4)?.try_into().unwrap()))
279 }
280
281 pub fn read_f64(&mut self) -> Result<f64, DecodeError> {
282 Ok(f64::from_le_bytes(self.take(8)?.try_into().unwrap()))
283 }
284
285 pub fn read_bool(&mut self) -> Result<bool, DecodeError> {
286 match self.read_u8()? {
287 0 => Ok(false),
288 1 => Ok(true),
289 b => Err(DecodeError::InvalidBool(b)),
290 }
291 }
292
293 pub fn read_char(&mut self) -> Result<char, DecodeError> {
296 let n = self.read_u32()?;
297 char::from_u32(n).ok_or(DecodeError::InvalidChar(n))
298 }
299
300 pub fn read_str(&mut self) -> Result<&'a str, DecodeError> {
303 let len = self.read_len(1)?;
304 let bytes = self.take(len)?;
305 core::str::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8)
306 }
307
308 pub fn read_bytes(&mut self) -> Result<&'a [u8], DecodeError> {
310 let len = self.read_len(1)?;
311 self.take(len)
312 }
313
314 pub fn read_len(&mut self, min_elem_size: usize) -> Result<usize, DecodeError> {
329 let count = self.read_u32()? as usize;
330 let remaining = self.remaining();
331 let max = remaining
335 .checked_div(min_elem_size)
336 .unwrap_or(ZST_COUNT_CAP);
337 if count > max {
338 return Err(DecodeError::LengthTooLarge {
339 count: count as u64,
340 remaining,
341 });
342 }
343 Ok(count)
344 }
345}
346
/// Largest element count [`Reader::read_len`] accepts when `min_elem_size` is `0`.
///
/// Zero-sized elements consume no input, so the remaining byte count cannot bound how
/// many of them a length prefix may claim; this fixed cap does instead.
/// Equal to 2^24 (16,777,216).
pub const ZST_COUNT_CAP: usize = 0x0100_0000;
353
354pub fn skip_pad(r: &mut Reader, n: usize) -> Result<(), DecodeError> {
360 while !r.position().is_multiple_of(n) {
361 r.read_u8()?;
362 }
363 Ok(())
364}