1use std::{any::type_name, slice, str};
4
5use thiserror::Error;
6
7use crate::{
8 common::{BytesBuf, FnSink},
9 DateTime, Level,
10};
11
12#[non_exhaustive]
14#[derive(Error, Clone, Debug)]
15pub enum EncodingError {
16 #[allow(dead_code)]
18 #[error("unreachable")]
19 None,
20}
21
/// Byte sink accepted by `Encode` implementations: the crate-wide
/// `crate::Sink` trait specialised to this module's `EncodingError`.
pub(crate) trait Sink = crate::Sink<EncodingError>;
24
25pub(crate) trait Encode {
33 fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
35 where
36 S: Sink;
37}
38
39#[derive(Error, Clone, Debug)]
41#[non_exhaustive]
42pub enum DecodingError {
43 #[error("the source reached its end, but more bytes ({extra_len}) were expected")]
45 UnexpectedEnd {
46 extra_len: usize,
48 },
49 #[error("invalid variant ({found_byte}) was found on type `{type_name}`")]
51 UnexpectedVariant {
52 type_name: &'static str,
54 found_byte: u8,
56 },
57 #[error(transparent)]
60 Str(#[from] str::Utf8Error),
61 #[error("the encoded varint is outside of the range of the target integral type")]
66 IntegerOverflow,
67 #[error("failed to decode date & time")]
70 DateTime,
71}
72
/// A source of encoded bytes that decoders pull from.
///
/// `'de` is the lifetime of the underlying data: slices handed out by
/// `read_bytes` borrow from it, which is what lets the `&[u8]` and `&str`
/// decoders in this module return borrowed views instead of copies.
pub(crate) trait Source<'de> {
    /// Error reported by this source; any `DecodingError` raised by a decoder
    /// is converted into it.
    type Error: From<DecodingError>;

    /// Takes the next `len` bytes from the source.
    ///
    /// NOTE(review): the `u8` and `[u8; N]` decoders `unwrap` on the assumption
    /// that a successful call returns exactly `len` bytes. The `&[u8]` impl in
    /// this module upholds that — confirm for any other implementation.
    fn read_bytes(&mut self, len: usize) -> Result<&'de [u8], Self::Error>;
}
80
81pub(crate) trait Decode<'de>: Sized {
92 fn decode<S>(source: &mut S) -> Result<Self, S::Error>
94 where
95 S: Source<'de>;
96}
97
98pub(crate) struct AccumulationEncoder {
105 buffer: BytesBuf,
106}
107
108impl AccumulationEncoder {
109 #[inline]
111 pub(crate) fn new(buffer_len: usize) -> Self {
112 Self { buffer: BytesBuf::with_capacity(buffer_len) }
113 }
114
115 pub(crate) fn encode<T, S>(&mut self, value: &T, sink: &mut S) -> Result<(), S::Error>
117 where
118 T: Encode,
119 S: Sink,
120 {
121 value.encode(&mut FnSink::new(|mut bytes: &[u8]| {
122 loop {
123 let buffered = self.buffer.buffer(bytes);
124 bytes = &bytes[buffered..];
125
126 if bytes.is_empty() {
128 break Ok(());
129 }
130
131 let result = sink.sink(&self.buffer);
133 if result.is_err() {
134 break result;
135 }
136
137 self.buffer.clear();
139 }
140 }))?;
141
142 sink.sink(&self.buffer)?;
144 self.buffer.clear();
145
146 Ok(())
147 }
148}
149
150impl Encode for u8 {
153 #[inline]
154 fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
155 where
156 S: Sink,
157 {
158 sink.sink(slice::from_ref(self))
159 }
160}
161
162impl<'de> Decode<'de> for u8 {
163 #[inline]
164 fn decode<S>(source: &mut S) -> Result<Self, S::Error>
165 where
166 S: Source<'de>,
167 {
168 let bytes = source.read_bytes(1)?;
169 Ok(*bytes.first().unwrap())
171 }
172}
173
174macro_rules! integral_type_codec_impl {
180 ($Self:ty) => {
181 integral_type_codec_impl!(encode: $Self);
182 integral_type_codec_impl!(decode: $Self);
183 };
184
185 (encode: $Self:ty) => {
186 impl Encode for $Self {
187 fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
188 where
189 S: Sink,
190 {
191 let mut val = *self;
192 loop {
193 if val <= 0x7F {
194 (val as u8).encode(sink)?;
195 break Ok(());
196 }
197 ((val & 0x7F) as u8 | 0x80).encode(sink)?;
198 val >>= 7;
199 }
200 }
201 }
202 };
203
204 (decode: $Self:ty) => {
205 impl<'de> Decode<'de> for $Self {
206 fn decode<S>(source: &mut S) -> Result<Self, S::Error>
207 where
208 S: Source<'de>,
209 {
210 let (mut val, mut shift) = (0, 0);
211 loop {
212 let byte = u8::decode(source)?;
213 let high_bits = byte as $Self & 0x7F;
214 if high_bits.leading_zeros() < shift {
216 break Err(DecodingError::IntegerOverflow.into());
217 }
218 val |= high_bits << shift;
219 if byte & 0x80 == 0 {
220 break Ok(val);
221 }
222 shift += 7;
223 }
224 }
225 }
226 };
227}
228
229integral_type_codec_impl!(u32);
230integral_type_codec_impl!(u64);
231integral_type_codec_impl!(usize);
232
233impl<const N: usize> Encode for &[u8; N] {
234 #[inline]
235 fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
236 where
237 S: Sink,
238 {
239 self.as_slice().encode(sink)
240 }
241}
242
243impl<'de: 'a, 'a, const N: usize> Decode<'de> for &'a [u8; N] {
244 #[inline]
245 fn decode<S>(source: &mut S) -> Result<Self, S::Error>
246 where
247 S: Source<'de>,
248 {
249 let bytes = source.read_bytes(N)?;
250 Ok(bytes.try_into().unwrap())
252 }
253}
254
255impl Encode for &[u8] {
256 #[inline]
257 fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
258 where
259 S: Sink,
260 {
261 self.len().encode(sink)?;
263 sink.sink(self)
264 }
265}
266
267impl<'de: 'a, 'a> Decode<'de> for &'a [u8] {
268 #[inline]
269 fn decode<S>(source: &mut S) -> Result<Self, S::Error>
270 where
271 S: Source<'de>,
272 {
273 let len = usize::decode(source)?;
275 source.read_bytes(len)
276 }
277}
278
279impl<'a> Source<'a> for &'a [u8] {
281 type Error = DecodingError;
282
283 fn read_bytes(&mut self, len: usize) -> Result<&'a [u8], Self::Error> {
284 if self.len() >= len {
285 let (bytes, remaining) = self.split_at(len);
286 *self = remaining;
287 Ok(bytes)
288 } else {
289 Err(DecodingError::UnexpectedEnd { extra_len: len - self.len() })
290 }
291 }
292}
293
294impl Encode for &str {
295 #[inline]
296 fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
297 where
298 S: Sink,
299 {
300 self.as_bytes().encode(sink)
301 }
302}
303
304impl<'de: 'a, 'a> Decode<'de> for &'a str {
305 #[inline]
306 fn decode<S>(source: &mut S) -> Result<Self, S::Error>
307 where
308 S: Source<'de>,
309 {
310 let bytes = Decode::decode(source)?;
311 str::from_utf8(bytes).map_err(|e| DecodingError::Str(e).into())
312 }
313}
314
/// Tag byte written for `Option::None` (no payload follows).
const OPTION_NONE_TAG: u8 = 0x00;
/// Tag byte written for `Option::Some`, followed by the encoded payload.
const OPTION_SOME_TAG: u8 = 0x01;
317
318impl<T> Encode for Option<T>
319where
320 T: Encode,
321{
322 #[inline]
323 fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
324 where
325 S: Sink,
326 {
327 match self {
329 None => OPTION_NONE_TAG.encode(sink),
330 Some(inner) => {
331 OPTION_SOME_TAG.encode(sink)?;
332 inner.encode(sink)
333 }
334 }
335 }
336}
337
338impl<'de, T> Decode<'de> for Option<T>
339where
340 T: Decode<'de>,
341{
342 fn decode<S>(source: &mut S) -> Result<Self, S::Error>
343 where
344 S: Source<'de>,
345 {
346 let tag = Decode::decode(source)?;
348 match tag {
349 OPTION_NONE_TAG => Ok(None),
350 OPTION_SOME_TAG => Decode::decode(source).map(Some),
351 _ => Err(DecodingError::UnexpectedVariant {
352 type_name: type_name::<Self>(),
353 found_byte: tag,
354 }
355 .into()),
356 }
357 }
358}
359
360impl Encode for Level {
361 #[inline]
362 fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
363 where
364 S: Sink,
365 {
366 self.primitive().encode(sink)
367 }
368}
369
370impl<'de> Decode<'de> for Level {
371 #[inline]
372 fn decode<S>(source: &mut S) -> Result<Self, S::Error>
373 where
374 S: Source<'de>,
375 {
376 let primitive = Decode::decode(source)?;
377 if let Some(level) = Level::from_primitive(primitive) {
378 Ok(level)
379 } else {
380 Err(DecodingError::UnexpectedVariant {
381 type_name: type_name::<Self>(),
382 found_byte: primitive,
383 }
384 .into())
385 }
386 }
387}
388
389impl Encode for DateTime {
390 #[inline]
391 fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
392 where
393 S: Sink,
394 {
395 self.timestamp().try_into().unwrap_or(0u64).encode(sink)?;
397 self.timestamp_subsec_nanos().encode(sink)
399 }
400}
401
402impl<'de> Decode<'de> for DateTime {
403 #[inline]
404 fn decode<S>(source: &mut S) -> Result<Self, S::Error>
405 where
406 S: Source<'de>,
407 {
408 let secs = u64::decode(source)?.try_into().map_err(|_| DecodingError::IntegerOverflow)?;
410 let nsecs = u32::decode(source)?;
412 DateTime::from_timestamp(secs, nsecs).ok_or(DecodingError::DateTime.into())
414 }
415}
416
#[cfg(test)]
mod tests {
    use crate::{
        codec::{Decode, DecodingError, Encode},
        DateTime,
    };

    /// Encodes `$val` as `$ty`, asserts it decodes back to `$val` with no bytes
    /// left over, and evaluates to the encoded bytes for further checks.
    macro_rules! roundtrip {
        ($ty:ty, $val:expr) => {{
            let original: $ty = $val;
            let mut encoded: Vec<u8> = Vec::new();
            original.encode(&mut encoded).unwrap();

            let mut rest = encoded.as_slice();
            let decoded = <$ty>::decode(&mut rest).unwrap();
            assert_eq!(decoded, $val);
            assert!(rest.is_empty());

            encoded
        }};
    }

    #[test]
    fn test_integer() {
        assert_eq!(roundtrip!(u32, 0x7F), [0x7F]);
        assert_eq!(roundtrip!(u64, 0x80), [0x80, 0x01]);
        assert_eq!(roundtrip!(u64, 0xC0C0C0C0), [0xC0, 0x81, 0x83, 0x86, 0x0C]);

        // 2^32 needs a fifth group holding bit 32, which a u32 cannot represent.
        let encoded = roundtrip!(u64, u64::from(u32::MAX) + 1);
        assert_eq!(encoded, [0x80, 0x80, 0x80, 0x80, 0x10]);
        let mut input = encoded.as_slice();
        let narrowed = u32::decode(&mut input);
        assert!(matches!(narrowed, Err(DecodingError::IntegerOverflow)));
    }

    #[test]
    fn test_option() {
        assert_eq!(roundtrip!(Option<u8>, None), [0x00]);
        assert_eq!(roundtrip!(Option<u8>, Some(0xFF)), [0x01, 0xFF]);
    }

    #[test]
    fn test_str() {
        assert_eq!(roundtrip!(&str, ""), [0x00]);
        let expected = [0x0B, 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x20, 0x57, 0x6F, 0x72, 0x6C, 0x64];
        assert_eq!(roundtrip!(&str, "Hello World"), expected);
    }

    #[test]
    fn test_datetime() {
        let now = chrono::Utc::now();
        roundtrip!(DateTime, now);
    }
}