gel_protogen/
encoding.rs

1use std::mem::MaybeUninit;
2
3use crate::datatypes::*;
4use crate::prelude::*;
5use crate::{declare_type, encoder_for_array};
6use uuid::Uuid;
7
8/// All data types must implement this trait. This allows for encoding and
9/// decoding of the data type to byte buffers.
10pub trait DataType
11where
12    Self: Sized,
13{
14    const META: StructFieldMeta;
15
16    #[allow(unused)]
17    fn encode_usize(buf: &mut BufWriter<'_>, value: usize) {
18        unreachable!("encode usize")
19    }
20    #[allow(unused)]
21    fn decode_usize(buf: &mut &[u8]) -> Result<usize, ParseError> {
22        unreachable!("decode usize")
23    }
24}
25
/// Data types whose representation always occupies the same number of bytes.
pub trait DataTypeFixedSize {
    /// Size, in bytes, of the type's fixed-size representation.
    const SIZE: usize;
}
31/// Marks a type as a builder for a given message.
32pub trait BuilderFor: EncoderFor<Self::Message> + Sized {
33    type Message: 'static;
34}
35
36/// Marks a type as a decoder for itself.
37pub trait DecoderFor<'a, F: 'a>: DataType + 'a {
38    fn decode_for(buf: &mut &'a [u8]) -> Result<F, ParseError>;
39}
40
41/// Marks a type as an encoder for a given type.
42pub trait EncoderFor<F: 'static> {
43    fn encode_for(&self, buf: &mut BufWriter<'_>);
44}
45
46/// Helper trait for encodable objects.
47pub trait EncoderForExt {
48    /// Convert this builder into a vector of bytes. This is generally
49    /// not the most efficient way to perform serialization.
50    #[allow(unused)]
51    fn to_vec<F: 'static>(&self) -> Vec<u8>
52    where
53        Self: EncoderFor<F>,
54    {
55        let mut vec = Vec::with_capacity(256);
56        let mut buf = BufWriter::new(&mut vec);
57        EncoderFor::<F>::encode_for(self, &mut buf);
58        match buf.finish() {
59            Ok(size) => {
60                vec.truncate(size);
61                vec
62            }
63            Err(size) => {
64                vec.resize(size, 0);
65                let mut buf = BufWriter::new(&mut vec);
66                EncoderFor::<F>::encode_for(self, &mut buf);
67                // Will not fail this second time
68                let size = buf.finish().unwrap();
69                vec.truncate(size);
70                vec
71            }
72        }
73    }
74
75    /// Encode this builder into a given buffer. If the buffer is
76    /// too small, the function will return the number of bytes
77    /// required to encode the builder.
78    #[allow(unused)]
79    fn encode_buffer<F: 'static>(&self, buf: &mut [u8]) -> Result<usize, usize>
80    where
81        Self: EncoderFor<F>,
82    {
83        let mut writer = BufWriter::new(buf);
84        EncoderFor::<F>::encode_for(self, &mut writer);
85        writer.finish()
86    }
87
88    /// Encode this builder into a given buffer. If the buffer is
89    /// too small, the function will return the number of bytes
90    /// required to encode the builder.
91    #[allow(unused)]
92    fn encode_buffer_uninit<'a, F: 'static>(
93        &self,
94        buf: &'a mut [MaybeUninit<u8>],
95    ) -> Result<&'a mut [u8], usize>
96    where
97        Self: EncoderFor<F>,
98    {
99        let mut writer = BufWriter::new_uninit(buf);
100        EncoderFor::<F>::encode_for(self, &mut writer);
101        writer.finish_buf()
102    }
103
104    #[allow(unused)]
105    fn measure<F: 'static>(&self) -> usize
106    where
107        Self: EncoderFor<F>,
108    {
109        let mut buf = Vec::new();
110        let mut writer = BufWriter::new(&mut buf);
111        EncoderFor::<F>::encode_for(self, &mut writer);
112        writer.finish().unwrap_err()
113    }
114}
115
116impl<T> EncoderForExt for T where T: ?Sized {}
117
118#[derive(derive_more::Error, derive_more::Display, Debug, Clone, Copy, PartialEq, Eq)]
119pub enum ParseError {
120    #[display("Buffer is too short")]
121    TooShort,
122    #[display("Buffer is too long ({_0} extra bytes)")]
123    TooLong(#[error(not(source))] usize),
124    #[display("Invalid data for {_0}: {_1}")]
125    InvalidData(
126        #[error(not(source))] &'static str,
127        #[error(not(source))] usize,
128    ),
129    #[display("Invalid data for field {_0}: {_1}")]
130    InvalidFieldData(
131        #[error(not(source))] &'static str,
132        #[error(not(source))] &'static str,
133    ),
134}
135
136impl<'a, L: DataType, T: DataType> DataType for Array<'a, L, T>
137where
138    T: DecoderFor<'a, T>,
139{
140    const META: StructFieldMeta = declare_meta!(
141        type = Array,
142        constant_size = None,
143        flags = [array]
144    );
145}
146
147impl<'a, L: DataType, T: DataType> DecoderFor<'a, Array<'a, L, T>> for Array<'a, L, T>
148where
149    L: 'a,
150    T: DecoderFor<'a, T>,
151{
152    fn decode_for(buf: &mut &'a [u8]) -> Result<Self, ParseError> {
153        let len = L::decode_usize(buf)?;
154        let orig_buf = *buf;
155        // Primitive types can skip the decode_for call.
156        if T::META.is_primitive {
157            let constant_size = T::META.constant_size.unwrap();
158            let byte_len = constant_size.saturating_mul(len);
159            if buf.len() < byte_len {
160                return Err(ParseError::TooShort);
161            }
162            *buf = &buf[byte_len..];
163            return Ok(Array::new(&orig_buf[..byte_len], len as _));
164        }
165        for _ in 0..len {
166            T::decode_for(buf)?;
167        }
168        let orig_buf = &orig_buf[0..orig_buf.len() - buf.len()];
169        Ok(Array::new(orig_buf, len as _))
170    }
171}
172
173encoder_for_array!(
174    impl <T, L> for Array<'static, L, T> {
175        fn encode_for(&self, buf: &mut BufWriter<'_>, it: impl ExactSizeIterator) {
176            L::encode_usize(buf, it.len());
177            for elem in it {
178                elem.encode_for(buf);
179            }
180        }
181    }
182);
183
184impl<'a, T: DataType> DataType for ZTArray<'a, T>
185where
186    T: DecoderFor<'a, T>,
187{
188    const META: StructFieldMeta = declare_meta!(
189        type = ZTArray,
190        constant_size = None,
191        flags = [array]
192    );
193}
194
195impl<'a, T: DataType> DecoderFor<'a, ZTArray<'a, T>> for ZTArray<'a, T>
196where
197    T: DecoderFor<'a, T>,
198{
199    fn decode_for(buf: &mut &'a [u8]) -> Result<Self, ParseError> {
200        let mut orig_buf = *buf;
201        let mut len = 0;
202
203        // Primitive types can skip the decode_for call and hunt for the 0 byte.
204        if T::META.is_primitive {
205            let constant_size = T::META.constant_size.unwrap();
206            loop {
207                if buf.is_empty() {
208                    return Err(ParseError::TooShort);
209                }
210                if buf[0] == 0 {
211                    break;
212                }
213                *buf = &buf[constant_size..];
214                len += 1;
215            }
216            *buf = &buf[1..];
217            orig_buf = &orig_buf[0..orig_buf.len() - buf.len() - 1];
218            return Ok(ZTArray::new(&orig_buf, len));
219        }
220
221        loop {
222            if buf.is_empty() {
223                return Err(crate::prelude::ParseError::TooShort);
224            }
225            if buf[0] == 0 {
226                orig_buf = &orig_buf[0..orig_buf.len() - buf.len()];
227                *buf = &buf[1..];
228                break;
229            }
230            T::decode_for(buf)?;
231            len += 1;
232        }
233        Ok(ZTArray::new(orig_buf, len))
234    }
235}
236
237encoder_for_array!(
238    impl <T> for ZTArray<'static, T> {
239        fn encode_for(&self, buf: &mut BufWriter<'_>, it: impl Iterator) {
240            for elem in it {
241                elem.encode_for(buf);
242            }
243            buf.write(&[0]);
244        }
245    }
246);
247
248impl<'a, T: DataType> DataType for RestArray<'a, T>
249where
250    T: DecoderFor<'a, T>,
251{
252    const META: StructFieldMeta = declare_meta!(
253        type = RestArray,
254        constant_size = None,
255        flags = [array]
256    );
257}
258
259impl<'a, T: DataType> DecoderFor<'a, RestArray<'a, T>> for RestArray<'a, T>
260where
261    T: DecoderFor<'a, T>,
262{
263    fn decode_for(buf: &mut &'a [u8]) -> Result<Self, ParseError> {
264        let orig_buf = *buf;
265        // Primitive types can skip the decode_for call and compute the number of elements
266        // until the end of the buffer.
267        if T::META.is_primitive {
268            let constant_size = T::META.constant_size.unwrap();
269            let len = buf.len() / constant_size;
270            if buf.len() % constant_size != 0 {
271                return Err(ParseError::TooShort);
272            }
273            *buf = &[];
274            return Ok(RestArray::new(orig_buf, len as _));
275        }
276        let mut len = 0;
277        while !buf.is_empty() {
278            T::decode_for(buf)?;
279            len += 1;
280        }
281        Ok(RestArray::new(orig_buf, len))
282    }
283}
284
285encoder_for_array!(
286    impl <T> for RestArray<'static, T> {
287        fn encode_for(&self, buf: &mut BufWriter<'_>, it: impl Iterator) {
288            for elem in it {
289                elem.encode_for(buf);
290            }
291        }
292    }
293);
294
295impl<const N: usize, T: DataType> DataType for [T; N]
296where
297    for<'a> T: Default + Copy,
298{
299    const META: StructFieldMeta = declare_meta!(
300        type = FixedArray,
301        constant_size = Some(std::mem::size_of::<T>() * N),
302        flags = [array]
303    );
304}
305
306impl<'a, T: DataType, const N: usize> DecoderFor<'a, [T; N]> for [T; N]
307where
308    T: DecoderFor<'a, T> + Default + Copy,
309{
310    fn decode_for(buf: &mut &'a [u8]) -> Result<Self, ParseError> {
311        let mut res = [T::default(); N];
312        for res in res.iter_mut().take(N) {
313            *res = T::decode_for(buf)?;
314        }
315        Ok(res)
316    }
317}
318
319impl<const N: usize, T: DataType> DataTypeFixedSize for [T; N] {
320    const SIZE: usize = std::mem::size_of::<T>() * N;
321}
322
323impl<const N: usize, T: DataType + 'static, U: EncoderFor<T>> EncoderFor<[T; N]> for [U; N] {
324    fn encode_for(&self, buf: &mut BufWriter<'_>) {
325        for elem in self {
326            U::encode_for(elem, buf);
327        }
328    }
329}
330
331/// Implements [`DataType`] and [`DataTypeFixedSize`] for tuples.
332macro_rules! tuple_type {
333    () => {};
334    ($head:ident $(, $tail:ident)*) => {
335        impl <$head: DataType, $($tail: DataType),*> DataType for ($head, $($tail),*) {
336            const META: StructFieldMeta = declare_meta!(type = Tuple, constant_size = None, flags = []);
337        }
338
339        impl <$head: DataType, $($tail: DataType),*> DataTypeFixedSize for ($head, $($tail),*) where $head: DataTypeFixedSize, $($tail: DataTypeFixedSize),* {
340            const SIZE: usize = $head::SIZE $(+ $tail::SIZE)*;
341        }
342
343        $crate::paste!(
344            /// Homomorphic mapping: If A: DecoderFor<A_X>, B: DecoderFor<B_X>, then (A, B): DecoderFor<(A_X, B_X)>
345            impl <'a,$head: DataType, $($tail: DataType),*> DecoderFor<'a, ($head, $($tail),*)> for ($head, $($tail),*) where $head: DecoderFor<'a, $head>, $($tail: DecoderFor<'a, $tail>),* {
346                fn decode_for(buf: &mut &'a [u8]) -> Result<Self, ParseError> {
347                    Ok((
348                        $head::decode_for(buf)?,
349                        $($tail::decode_for(buf)?),*
350                    ))
351                }
352            }
353
354            /// Homomorphic mapping: If A: EncoderFor<A_X>, B: EncoderFor<B_X>, then (A, B): EncoderFor<(A_X, B_X)>
355            impl <$head, [<$head X>]: 'static, $($tail, [<$tail X>]: 'static),*>
356                EncoderFor<([<$head X>], $([<$tail X>]),*)> for ($head, $($tail),*)
357
358                where $head: EncoderFor<[<$head X>]>, $($tail: EncoderFor<[<$tail X>]>),* {
359
360                fn encode_for(&self, buf: &mut BufWriter<'_>) {
361                    #[allow(non_snake_case)]
362                    let ($head, $($tail),*) = self;
363                    EncoderFor::<[<$head X>]>::encode_for($head, buf);
364                    $(
365                        EncoderFor::<[<$tail X>]>::encode_for($tail, buf);
366                    )*
367                }
368            }
369        );
370
371        // recurse
372        tuple_type!($($tail),*);
373    };
374}
375
376// Up to 52 fields seems reasonable.
377tuple_type!(
378    A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z, A1, B1, C1, D1,
379    E1, F1, G1, H1, I1, J1, K1, L1, M1, N1, O1, P1, Q1, R1, S1, T1, U1, V1, W1, X1, Y1, Z1
380);
381
382declare_type!(DataType, Rest<'a>, builder: &'a [u8],
383{}
384);
385
386impl<'a> DecoderFor<'a, Rest<'a>> for Rest<'a> {
387    fn decode_for(buf: &mut &'a [u8]) -> Result<Self, ParseError> {
388        let res = Rest::new(buf);
389        *buf = &[];
390        Ok(res)
391    }
392}
393
394impl<T> EncoderFor<Rest<'static>> for T
395where
396    T: AsRef<[u8]>,
397{
398    fn encode_for(&self, buf: &mut BufWriter<'_>) {
399        buf.write(self.as_ref());
400    }
401}
402
403declare_type!(DataType, LString<'a>, builder: &'a str, {});
404declare_type!(DataType, ZTString<'a>, builder: &'a str, {});
405declare_type!(DataType, RestString<'a>, builder: &'a str, {});
406
407impl<'a, A> DecoderFor<'a, ArrayString<'a, A>> for ArrayString<'a, A>
408where
409    A: ArrayExt<'a>,
410    A: DecoderFor<'a, A>,
411    A: DataType,
412    Self: DataType,
413{
414    fn decode_for(buf: &mut &'a [u8]) -> Result<ArrayString<'a, A>, ParseError> {
415        let arr = A::decode_for(buf)?;
416        Ok(ArrayString::new(arr.into_slice()))
417    }
418}
419
420impl<T, A> EncoderFor<ArrayString<'static, A>> for T
421where
422    for<'any> &'any T: AsRef<str>,
423    A: AsRef<[u8]>,
424    A: 'static,
425    for<'any> &'any [u8]: EncoderFor<A>,
426{
427    fn encode_for(&self, buf: &mut BufWriter<'_>) {
428        let bytes = self.as_ref().as_bytes();
429        bytes.encode_for(buf);
430    }
431}
432
433declare_type!(DataType, Encoded<'a>, builder: Encoded<'a>, {});
434
435impl<'a> DecoderFor<'a, Encoded<'a>> for Encoded<'a> {
436    fn decode_for(buf: &mut &'a [u8]) -> Result<Self, ParseError> {
437        if let Some((len, array)) = buf.split_first_chunk::<{ std::mem::size_of::<i32>() }>() {
438            let len = i32::from_be_bytes(*len);
439            if len == -1 {
440                *buf = array;
441                Ok(Encoded::Null)
442            } else if len < 0 {
443                Err(ParseError::InvalidData("Encoded", len as usize))
444            } else if array.len() < len as _ {
445                Err(ParseError::TooShort)
446            } else {
447                *buf = &array[len as usize..];
448                Ok(Encoded::Value(&array[..len as usize]))
449            }
450        } else {
451            Err(ParseError::TooShort)
452        }
453    }
454}
455
456impl<T> EncoderFor<Encoded<'static>> for Option<T>
457where
458    T: AsRef<[u8]>,
459{
460    fn encode_for(&self, buf: &mut BufWriter<'_>) {
461        match self {
462            Some(value) => buf.write(value.as_ref()),
463            None => buf.write(&(-1_i32).to_be_bytes()),
464        }
465    }
466}
467
468impl EncoderFor<Encoded<'static>> for Encoded<'_> {
469    fn encode_for(&self, buf: &mut BufWriter<'_>) {
470        match self {
471            Encoded::Null => buf.write(&(-1_i32).to_be_bytes()),
472            Encoded::Value(value) => {
473                let len: i32 = value.len() as _;
474                buf.write(&len.to_be_bytes());
475                buf.write(value);
476            }
477        }
478    }
479}
480
481impl EncoderFor<Encoded<'static>> for &'_ Encoded<'_> {
482    fn encode_for(&self, buf: &mut BufWriter<'_>) {
483        match self {
484            Encoded::Null => buf.write(&(-1_i32).to_be_bytes()),
485            Encoded::Value(value) => {
486                let len: i32 = value.len() as _;
487                buf.write(&len.to_be_bytes());
488                buf.write(value);
489            }
490        }
491    }
492}
493
494declare_type!(DataType, Length, flags = [length], {
495    fn to_usize(value: usize) -> Length {
496        Length(value as _)
497    }
498    fn from_usize(value: Length) -> usize {
499        value.0 as usize
500    }
501});
502
503impl<'a> DecoderFor<'a, Length> for Length {
504    fn decode_for(buf: &mut &'a [u8]) -> Result<Self, ParseError> {
505        i32::decode_for(buf).map(Length)
506    }
507}
508
509impl EncoderFor<Length> for u32 {
510    fn encode_for(&self, buf: &mut BufWriter<'_>) {
511        buf.write(&self.to_be_bytes());
512    }
513}
514
515impl EncoderFor<Length> for Length {
516    fn encode_for(&self, buf: &mut BufWriter<'_>) {
517        buf.write(&self.0.to_be_bytes());
518    }
519}
520
521declare_type!(DataType, Uuid, {});
522
523impl<'a> DecoderFor<'a, Uuid> for Uuid {
524    fn decode_for(buf: &mut &'a [u8]) -> Result<Self, ParseError> {
525        <[u8; 16] as DecoderFor<'a, [u8; 16]>>::decode_for(buf).map(Uuid::from_bytes)
526    }
527}
528
529impl EncoderFor<Uuid> for &'_ Uuid {
530    fn encode_for(&self, buf: &mut BufWriter<'_>) {
531        buf.write(&self.into_bytes());
532    }
533}
534
535impl EncoderFor<Uuid> for Uuid {
536    fn encode_for(&self, buf: &mut BufWriter<'_>) {
537        buf.write(&self.into_bytes());
538    }
539}
540
541impl<T> DataType for LengthPrefixed<T>
542where
543    T: DataType,
544{
545    const META: StructFieldMeta = T::META;
546}
547
548impl<'a, T> DecoderFor<'a, LengthPrefixed<T>> for LengthPrefixed<T>
549where
550    T: DecoderFor<'a, T>,
551{
552    fn decode_for(buf: &mut &'a [u8]) -> Result<Self, ParseError> {
553        let len = u32::decode_for(buf)?;
554        if len > buf.len() as u32 {
555            return Err(ParseError::TooShort);
556        }
557        let mut inner_buf = &buf[..len as usize];
558        *buf = &buf[len as usize..];
559        // The inner object must consume the entire buffer.
560        let inner = T::decode_for(&mut inner_buf)?;
561        if inner_buf.len() != 0 {
562            return Err(ParseError::InvalidData("LengthPrefixed", inner_buf.len()));
563        }
564        Ok(LengthPrefixed(inner))
565    }
566}
567
568impl<T, U> EncoderFor<LengthPrefixed<T>> for LengthPrefixed<U>
569where
570    U: EncoderFor<T>,
571    T: 'static,
572{
573    fn encode_for(&self, buf: &mut BufWriter<'_>) {
574        let offset = buf.size();
575        U::encode_for(&self.0, buf);
576        let len = buf.size() - offset;
577        buf.write_rewind(offset, &len.to_be_bytes());
578    }
579}
580
581declare_type!(u8);
582declare_type!(u16);
583declare_type!(u32);
584declare_type!(u64);
585declare_type!(u128);
586declare_type!(i8);
587declare_type!(i16);
588declare_type!(i32);
589declare_type!(i64);
590declare_type!(i128);
591
592declare_type!(f32);
593declare_type!(f64);
594
#[cfg(test)]
mod tests {
    use super::*;
    use static_assertions::assert_impl_all;

    // Fixed-size primitives, arrays and tuples.
    assert_impl_all!(u8: DataType, DataTypeFixedSize);
    assert_impl_all!([u8; 4]: DataType, DataTypeFixedSize, DecoderFor<'static, [u8; 4]>);
    assert_impl_all!((u8, u8): DataType, DataTypeFixedSize, EncoderFor<(u8, u8)>);

    // String-like values that can build an `LString`.
    assert_impl_all!(&'static str: EncoderFor<LString<'static>>);
    assert_impl_all!(String: EncoderFor<LString<'static>>);
    assert_impl_all!(&'static String: EncoderFor<LString<'static>>);
}