pinenut_log/
codec.rs

1//! Encoding & Decoding.
2
3use std::{any::type_name, slice, str};
4
5use thiserror::Error;
6
7use crate::{
8    common::{BytesBuf, FnSink},
9    DateTime, Level,
10};
11
12/// Errors that can be occurred by encoding a type.
13#[non_exhaustive]
14#[derive(Error, Clone, Debug)]
15pub enum EncodingError {
16    /// No errors yet.
17    #[allow(dead_code)]
18    #[error("unreachable")]
19    None,
20}
21
22/// Represents a target for encoded data.
23pub(crate) trait Sink = crate::Sink<EncodingError>;
24
25/// Any data type that can be encoded.
26///
27/// `Pinenut` encodes the data into a stream of compact binary bytes and outputs to
28/// the `Sink`.
29///
30/// This trait will be automatically implemented if you add `#[derive(Encode)]` to a
31/// struct.
32pub(crate) trait Encode {
33    /// Encode the data and write encoded bytes to `Sink`.
34    fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
35    where
36        S: Sink;
37}
38
39/// Errors that can be occurred by decoding a type.
40#[derive(Error, Clone, Debug)]
41#[non_exhaustive]
42pub enum DecodingError {
43    /// The source reached its end but more bytes were expected.
44    #[error("the source reached its end, but more bytes ({extra_len}) were expected")]
45    UnexpectedEnd {
46        /// How many extra bytes are needed.
47        extra_len: usize,
48    },
49    /// Invalid variant was found. This error is generally for enums.
50    #[error("invalid variant ({found_byte}) was found on type `{type_name}`")]
51    UnexpectedVariant {
52        /// The type name that was being decoded.
53        type_name: &'static str,
54        /// The byte that has been read.
55        found_byte: u8,
56    },
57    /// Which can be occurred when attempting to decode bytes as a `str`, it is
58    /// essentially an UTF-8 error.
59    #[error(transparent)]
60    Str(#[from] str::Utf8Error),
61    /// The encoded varint is outside of the range of the target integral type.
62    ///
63    /// This may happen if an usize was encoded on 64-bit architecture and then
64    /// decoded on 32-bit architecture (from large type to small type).
65    #[error("the encoded varint is outside of the range of the target integral type")]
66    IntegerOverflow,
67    /// Which can be occurred on out-of-range number of seconds and/or invalid
68    /// nanosecond.
69    #[error("failed to decode date & time")]
70    DateTime,
71}
72
73/// Represents a provider for encoded data.
74pub(crate) trait Source<'de> {
75    type Error: From<DecodingError>;
76
77    /// Take a length and attempt to read that many bytes.
78    fn read_bytes(&mut self, len: usize) -> Result<&'de [u8], Self::Error>;
79}
80
81/// Any data type that can be decoded.
82///
83/// `Pinenut` decodes the data by continuously reading a stream of compact binary
84/// bytes from the `Source`.
85///
86/// The `'de` lifetime is what enables `Pinenut` to safely perform efficient
87/// zero-copy decoding across a variety of data formats.
88///
89/// This trait will be automatically implemented if you add `#[derive(Decode)]` to a
90/// struct.
91pub(crate) trait Decode<'de>: Sized {
92    /// Decode the data from `Source`.
93    fn decode<S>(source: &mut S) -> Result<Self, S::Error>
94    where
95        S: Source<'de>;
96}
97
98/// Used to accumulate the data generated during encoding and reduce the callback
99/// frequency.
100///
101/// In order to reduce the frequency of calling the Sink, a buffer is used to
102/// temporarily store data. When the buffer is full, it will be flushed to the Sink,
103/// otherwise it will continue to wait for the buffer to be filled.
104pub(crate) struct AccumulationEncoder {
105    buffer: BytesBuf,
106}
107
108impl AccumulationEncoder {
109    /// Constructs a new `AccumulationEncoder`.
110    #[inline]
111    pub(crate) fn new(buffer_len: usize) -> Self {
112        Self { buffer: BytesBuf::with_capacity(buffer_len) }
113    }
114
115    /// Encode the data and write encoded bytes to `Sink`.
116    pub(crate) fn encode<T, S>(&mut self, value: &T, sink: &mut S) -> Result<(), S::Error>
117    where
118        T: Encode,
119        S: Sink,
120    {
121        value.encode(&mut FnSink::new(|mut bytes: &[u8]| {
122            loop {
123                let buffered = self.buffer.buffer(bytes);
124                bytes = &bytes[buffered..];
125
126                // The buffer is not full.
127                if bytes.is_empty() {
128                    break Ok(());
129                }
130
131                // The buffer is full, flushes it into Sink.
132                let result = sink.sink(&self.buffer);
133                if result.is_err() {
134                    break result;
135                }
136
137                // Keeps waiting for the data to fill in.
138                self.buffer.clear();
139            }
140        }))?;
141
142        // Flushes the buffer into the Sink.
143        sink.sink(&self.buffer)?;
144        self.buffer.clear();
145
146        Ok(())
147    }
148}
149
150// ============ Implementations ============
151
152impl Encode for u8 {
153    #[inline]
154    fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
155    where
156        S: Sink,
157    {
158        sink.sink(slice::from_ref(self))
159    }
160}
161
162impl<'de> Decode<'de> for u8 {
163    #[inline]
164    fn decode<S>(source: &mut S) -> Result<Self, S::Error>
165    where
166        S: Source<'de>,
167    {
168        let bytes = source.read_bytes(1)?;
169        // `source` is responsible for errors handling, so use `unwrap` directly here.
170        Ok(*bytes.first().unwrap())
171    }
172}
173
174/// Implements `Encode` and `Decode` traits for specified integral type, using
175/// `varint` (variable length integer) encoding.
176///
177/// Currently, encoding negative integers is not supported. `ZigZag` encoding may be
178/// used in the future.
179macro_rules! integral_type_codec_impl {
180    ($Self:ty) => {
181        integral_type_codec_impl!(encode: $Self);
182        integral_type_codec_impl!(decode: $Self);
183    };
184
185    (encode: $Self:ty) => {
186        impl Encode for $Self {
187            fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
188            where
189                S: Sink,
190            {
191                let mut val = *self;
192                loop {
193                    if val <= 0x7F {
194                        (val as u8).encode(sink)?;
195                        break Ok(());
196                    }
197                    ((val & 0x7F) as u8 | 0x80).encode(sink)?;
198                    val >>= 7;
199                }
200            }
201        }
202    };
203
204    (decode: $Self:ty) => {
205        impl<'de> Decode<'de> for $Self {
206            fn decode<S>(source: &mut S) -> Result<Self, S::Error>
207            where
208                S: Source<'de>,
209            {
210                let (mut val, mut shift) = (0, 0);
211                loop {
212                    let byte = u8::decode(source)?;
213                    let high_bits = byte as $Self & 0x7F;
214                    // Check for overflow.
215                    if high_bits.leading_zeros() < shift {
216                        break Err(DecodingError::IntegerOverflow.into());
217                    }
218                    val |= high_bits << shift;
219                    if byte & 0x80 == 0 {
220                        break Ok(val);
221                    }
222                    shift += 7;
223                }
224            }
225        }
226    };
227}
228
229integral_type_codec_impl!(u32);
230integral_type_codec_impl!(u64);
231integral_type_codec_impl!(usize);
232
233impl<const N: usize> Encode for &[u8; N] {
234    #[inline]
235    fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
236    where
237        S: Sink,
238    {
239        self.as_slice().encode(sink)
240    }
241}
242
243impl<'de: 'a, 'a, const N: usize> Decode<'de> for &'a [u8; N] {
244    #[inline]
245    fn decode<S>(source: &mut S) -> Result<Self, S::Error>
246    where
247        S: Source<'de>,
248    {
249        let bytes = source.read_bytes(N)?;
250        // `source` is responsible for errors handling, so use `unwrap` directly here.
251        Ok(bytes.try_into().unwrap())
252    }
253}
254
255impl Encode for &[u8] {
256    #[inline]
257    fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
258    where
259        S: Sink,
260    {
261        // Encode the length first, then the payload.
262        self.len().encode(sink)?;
263        sink.sink(self)
264    }
265}
266
267impl<'de: 'a, 'a> Decode<'de> for &'a [u8] {
268    #[inline]
269    fn decode<S>(source: &mut S) -> Result<Self, S::Error>
270    where
271        S: Source<'de>,
272    {
273        // Decode the length first, then read bytes of length.
274        let len = usize::decode(source)?;
275        source.read_bytes(len)
276    }
277}
278
279// `&[u8]` is also a `Source`.
280impl<'a> Source<'a> for &'a [u8] {
281    type Error = DecodingError;
282
283    fn read_bytes(&mut self, len: usize) -> Result<&'a [u8], Self::Error> {
284        if self.len() >= len {
285            let (bytes, remaining) = self.split_at(len);
286            *self = remaining;
287            Ok(bytes)
288        } else {
289            Err(DecodingError::UnexpectedEnd { extra_len: len - self.len() })
290        }
291    }
292}
293
294impl Encode for &str {
295    #[inline]
296    fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
297    where
298        S: Sink,
299    {
300        self.as_bytes().encode(sink)
301    }
302}
303
304impl<'de: 'a, 'a> Decode<'de> for &'a str {
305    #[inline]
306    fn decode<S>(source: &mut S) -> Result<Self, S::Error>
307    where
308        S: Source<'de>,
309    {
310        let bytes = Decode::decode(source)?;
311        str::from_utf8(bytes).map_err(|e| DecodingError::Str(e).into())
312    }
313}
314
/// Tag byte that encodes `None`; nothing follows it.
const OPTION_NONE_TAG: u8 = 0;
/// Tag byte that encodes `Some`; the encoded payload follows it.
const OPTION_SOME_TAG: u8 = 1;
317
318impl<T> Encode for Option<T>
319where
320    T: Encode,
321{
322    #[inline]
323    fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
324    where
325        S: Sink,
326    {
327        // Encode the tag first, then the payload if there is one.
328        match self {
329            None => OPTION_NONE_TAG.encode(sink),
330            Some(inner) => {
331                OPTION_SOME_TAG.encode(sink)?;
332                inner.encode(sink)
333            }
334        }
335    }
336}
337
338impl<'de, T> Decode<'de> for Option<T>
339where
340    T: Decode<'de>,
341{
342    fn decode<S>(source: &mut S) -> Result<Self, S::Error>
343    where
344        S: Source<'de>,
345    {
346        // Decode the tag first, then the payload if the tag is `Some`.
347        let tag = Decode::decode(source)?;
348        match tag {
349            OPTION_NONE_TAG => Ok(None),
350            OPTION_SOME_TAG => Decode::decode(source).map(Some),
351            _ => Err(DecodingError::UnexpectedVariant {
352                type_name: type_name::<Self>(),
353                found_byte: tag,
354            }
355            .into()),
356        }
357    }
358}
359
360impl Encode for Level {
361    #[inline]
362    fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
363    where
364        S: Sink,
365    {
366        self.primitive().encode(sink)
367    }
368}
369
370impl<'de> Decode<'de> for Level {
371    #[inline]
372    fn decode<S>(source: &mut S) -> Result<Self, S::Error>
373    where
374        S: Source<'de>,
375    {
376        let primitive = Decode::decode(source)?;
377        if let Some(level) = Level::from_primitive(primitive) {
378            Ok(level)
379        } else {
380            Err(DecodingError::UnexpectedVariant {
381                type_name: type_name::<Self>(),
382                found_byte: primitive,
383            }
384            .into())
385        }
386    }
387}
388
389impl Encode for DateTime {
390    #[inline]
391    fn encode<S>(&self, sink: &mut S) -> Result<(), S::Error>
392    where
393        S: Sink,
394    {
395        // Encode `secs`. It can't be earlier than the midnight on January 1, 1970.
396        self.timestamp().try_into().unwrap_or(0u64).encode(sink)?;
397        // Encode `nsecs`.
398        self.timestamp_subsec_nanos().encode(sink)
399    }
400}
401
402impl<'de> Decode<'de> for DateTime {
403    #[inline]
404    fn decode<S>(source: &mut S) -> Result<Self, S::Error>
405    where
406        S: Source<'de>,
407    {
408        // Decode `secs`.
409        let secs = u64::decode(source)?.try_into().map_err(|_| DecodingError::IntegerOverflow)?;
410        // Decode `nsecs`.
411        let nsecs = u32::decode(source)?;
412        // Make date & time.
413        DateTime::from_timestamp(secs, nsecs).ok_or(DecodingError::DateTime.into())
414    }
415}
416
#[cfg(test)]
mod tests {
    use crate::{
        codec::{Decode, DecodingError, Encode},
        DateTime,
    };

    /// Codec testing helper.
    ///
    /// Takes a type and a value, checks that the value round-trips through
    /// encoding and decoding with no bytes left over, and returns the encoded bytes.
    macro_rules! test_coding {
        ($ty:ty, $val:expr) => {{
            let mut sink = Vec::new();

            let val: $ty = $val;
            val.encode(&mut sink).unwrap();

            let mut source = sink.as_slice();
            assert_eq!(<$ty>::decode(&mut source).unwrap(), $val);
            assert!(source.is_empty());

            sink
        }};
    }

    #[test]
    fn test_integer() {
        assert_eq!(test_coding!(u32, 0x7F), [0x7F]);
        assert_eq!(test_coding!(u64, 0x80), [0x80, 0x01]);
        assert_eq!(test_coding!(u64, 0xC0C0C0C0), [0xC0, 0x81, 0x83, 0x86, 0x0C]);
        // Test for overflow.
        let sink = test_coding!(u64, u32::MAX as u64 + 1);
        assert_eq!(sink, [0x80, 0x80, 0x80, 0x80, 0x10]);
        let mut source = sink.as_slice();
        assert!(matches!(u32::decode(&mut source), Err(DecodingError::IntegerOverflow)));
    }

    #[test]
    fn test_option() {
        assert_eq!(test_coding!(Option<u8>, None), [0x00]);
        assert_eq!(test_coding!(Option<u8>, Some(0xFF)), [0x01, 0xFF]);
    }

    #[test]
    fn test_str() {
        assert_eq!(test_coding!(&str, ""), [0x00]);
        assert_eq!(
            test_coding!(&str, "Hello World"),
            [0x0B, 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x20, 0x57, 0x6F, 0x72, 0x6C, 0x64]
        );
    }

    #[test]
    fn test_datetime() {
        let datetime = chrono::Utc::now();
        test_coding!(DateTime, datetime);
    }

    #[test]
    fn test_unexpected_end() {
        // The continuation bit promises another varint byte that never comes.
        let mut source: &[u8] = &[0x80];
        assert!(matches!(
            u32::decode(&mut source),
            Err(DecodingError::UnexpectedEnd { extra_len: 1 })
        ));

        // The length prefix claims 3 bytes, but only 1 is available.
        let mut source: &[u8] = &[0x03, b'a'];
        assert!(matches!(
            <&str>::decode(&mut source),
            Err(DecodingError::UnexpectedEnd { extra_len: 2 })
        ));
    }

    #[test]
    fn test_invalid_option_tag() {
        let mut source: &[u8] = &[0x02, 0xFF];
        assert!(matches!(
            Option::<u8>::decode(&mut source),
            Err(DecodingError::UnexpectedVariant { found_byte: 0x02, .. })
        ));
    }

    #[test]
    fn test_invalid_utf8() {
        let mut source: &[u8] = &[0x01, 0xFF];
        assert!(matches!(<&str>::decode(&mut source), Err(DecodingError::Str(_))));
    }
}