fifthtry_serde_sqlite_jsonb/
ser.rs

1use crate::{
2    error::{Error, Result},
3    header::ElementType,
4};
5use serde::ser::{self, Serialize};
6use std::io::Write;
7
/// Serde serializer that accumulates JSONB bytes in an in-memory buffer.
///
/// The finished bytes are handed out by [`to_vec`].
#[derive(Default, Debug)]
pub struct Serializer {
    /// Bytes of every element written so far.
    buffer: Vec<u8>,
}
12
13/// Serialize a value into a JSONB byte array
14pub fn to_vec<T>(value: &T) -> Result<Vec<u8>>
15where
16    T: Serialize,
17{
18    let mut serializer = Serializer::default();
19    value.serialize(&mut serializer)?;
20    Ok(serializer.buffer)
21}
22
/// Writes a single JSONB element into a shared buffer.
///
/// A maximum-size header is reserved up front; once the payload has been
/// appended, the header is rewritten in its shortest form.
pub struct JsonbWriter<'a> {
    /// Output buffer the element (header and payload) is written into.
    buffer: &'a mut Vec<u8>,
    /// Offset in `buffer` of the first reserved header byte.
    header_start: u64,
}
28
29impl<'a> JsonbWriter<'a> {
30    fn new(buffer: &'a mut Vec<u8>, element_type: ElementType) -> Self {
31        let header_start = buffer.len() as u64;
32        buffer.extend_from_slice(&[u8::from(element_type); 9]);
33        Self {
34            buffer,
35            header_start,
36        }
37    }
38    fn finalize(self) {
39        let data_start = self.header_start as usize + 9;
40        let data_end = self.buffer.len();
41        let payload_size = data_end - data_start;
42        let header = &mut self.buffer
43            [(self.header_start as usize)..(self.header_start as usize) + 9];
44        let head_len = if payload_size <= 11 {
45            header[0] |= (payload_size as u8) << 4;
46            1
47        } else if payload_size <= 0xff {
48            header[0] |= 0xc0;
49            header[1] = payload_size as u8;
50            2
51        } else if payload_size <= 0xffff {
52            header[0] |= 0xd0;
53            header[1..3].copy_from_slice(&(payload_size as u16).to_be_bytes());
54            3
55        } else if payload_size <= 0xffffffff {
56            header[0] |= 0xe0;
57            header[1..5].copy_from_slice(&(payload_size as u32).to_be_bytes());
58            5
59        } else {
60            header[0] |= 0xf0;
61            header[1..9].copy_from_slice(&payload_size.to_be_bytes());
62            9
63        };
64        if head_len < 9 {
65            self.buffer.copy_within(
66                data_start..data_end,
67                self.header_start as usize + head_len,
68            );
69            self.buffer
70                .truncate(self.header_start as usize + head_len + payload_size);
71        }
72    }
73}
74
75impl Serializer {
76    fn write_header_nodata(&mut self, element_type: ElementType) -> Result<()> {
77        self.buffer.push(u8::from(element_type));
78        Ok(())
79    }
80
81    fn write_displayable(
82        &mut self,
83        element_type: ElementType,
84        data: impl std::fmt::Display,
85    ) -> Result<()> {
86        let mut w = JsonbWriter::new(&mut self.buffer, element_type);
87        write!(&mut w.buffer, "{}", data)?;
88        w.finalize();
89        Ok(())
90    }
91}
92
93impl<'a> ser::Serializer for &'a mut Serializer {
94    type Ok = ();
95
96    type Error = Error;
97
98    type SerializeSeq = JsonbWriter<'a>;
99
100    type SerializeTuple = JsonbWriter<'a>;
101
102    type SerializeTupleStruct = JsonbWriter<'a>;
103
104    type SerializeTupleVariant = EnumVariantSerializer<'a>;
105
106    type SerializeMap = JsonbWriter<'a>;
107
108    type SerializeStruct = JsonbWriter<'a>;
109
110    type SerializeStructVariant = EnumVariantSerializer<'a>;
111
112    fn serialize_bool(self, v: bool) -> Result<Self::Ok> {
113        self.write_header_nodata(if v {
114            ElementType::True
115        } else {
116            ElementType::False
117        })
118    }
119
120    fn serialize_i8(self, v: i8) -> Result<Self::Ok> {
121        self.write_displayable(ElementType::Int, v)
122    }
123
124    fn serialize_i16(self, v: i16) -> Result<Self::Ok> {
125        self.write_displayable(ElementType::Int, v)
126    }
127
128    fn serialize_i32(self, v: i32) -> Result<Self::Ok> {
129        self.write_displayable(ElementType::Int, v)
130    }
131
132    fn serialize_i64(self, v: i64) -> Result<Self::Ok> {
133        self.write_displayable(ElementType::Int, v)
134    }
135
136    fn serialize_u8(self, v: u8) -> Result<Self::Ok> {
137        self.write_displayable(ElementType::Int, v)
138    }
139
140    fn serialize_u16(self, v: u16) -> Result<Self::Ok> {
141        self.write_displayable(ElementType::Int, v)
142    }
143
144    fn serialize_u32(self, v: u32) -> Result<Self::Ok> {
145        self.write_displayable(ElementType::Int, v)
146    }
147
148    fn serialize_u64(self, v: u64) -> Result<Self::Ok> {
149        self.write_displayable(ElementType::Int, v)
150    }
151
152    fn serialize_f32(self, v: f32) -> Result<Self::Ok> {
153        self.write_displayable(ElementType::Float, v)
154    }
155
156    fn serialize_f64(self, v: f64) -> Result<Self::Ok> {
157        self.write_displayable(ElementType::Float, v)
158    }
159
160    fn serialize_char(self, v: char) -> Result<Self::Ok> {
161        self.write_displayable(ElementType::TextRaw, v)
162    }
163
164    fn serialize_str(self, v: &str) -> Result<Self::Ok> {
165        self.write_displayable(ElementType::TextRaw, v)
166    }
167
168    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok> {
169        use serde::ser::SerializeSeq;
170        let mut s = self.serialize_seq(Some(v.len()))?;
171        for byte in v {
172            s.serialize_element(byte)?;
173        }
174        s.end()
175    }
176
177    fn serialize_none(self) -> Result<Self::Ok> {
178        self.serialize_unit()
179    }
180
181    fn serialize_some<T: ?Sized + Serialize>(
182        self,
183        value: &T,
184    ) -> Result<Self::Ok> {
185        T::serialize(value, self)
186    }
187
188    fn serialize_unit(self) -> Result<Self::Ok> {
189        self.write_header_nodata(ElementType::Null)
190    }
191
192    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok> {
193        self.serialize_unit()
194    }
195
196    fn serialize_unit_variant(
197        self,
198        _name: &'static str,
199        _variant_index: u32,
200        variant: &'static str,
201    ) -> Result<Self::Ok> {
202        self.serialize_str(variant)
203    }
204
205    fn serialize_newtype_struct<T: ?Sized + Serialize>(
206        self,
207        _name: &'static str,
208        _value: &T,
209    ) -> Result<Self::Ok> {
210        self.serialize_unit()
211    }
212
213    fn serialize_newtype_variant<T: ?Sized + Serialize>(
214        self,
215        _name: &'static str,
216        _variant_index: u32,
217        variant: &'static str,
218        value: &T,
219    ) -> Result<Self::Ok> {
220        let mut map = self.serialize_map(Some(1))?;
221        serde::ser::SerializeMap::serialize_key(&mut map, variant)?;
222        serde::ser::SerializeMap::serialize_value(&mut map, value)?;
223        serde::ser::SerializeMap::end(map)
224    }
225
226    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq> {
227        Ok(JsonbWriter::new(&mut self.buffer, ElementType::Array))
228    }
229
230    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
231        Ok(JsonbWriter::new(&mut self.buffer, ElementType::Array))
232    }
233
234    fn serialize_tuple_struct(
235        self,
236        _name: &'static str,
237        len: usize,
238    ) -> Result<Self::SerializeTupleStruct> {
239        self.serialize_tuple(len)
240    }
241
242    fn serialize_tuple_variant(
243        self,
244        _name: &'static str,
245        _variant_index: u32,
246        variant: &'static str,
247        _len: usize,
248    ) -> Result<Self::SerializeTupleVariant> {
249        Ok(EnumVariantSerializer::new(
250            &mut self.buffer,
251            variant,
252            ElementType::Array,
253        ))
254    }
255
256    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
257        Ok(JsonbWriter::new(&mut self.buffer, ElementType::Object))
258    }
259
260    fn serialize_struct(
261        self,
262        _name: &'static str,
263        len: usize,
264    ) -> Result<Self::SerializeStruct> {
265        self.serialize_map(Some(len))
266    }
267
268    fn serialize_struct_variant(
269        self,
270        _name: &'static str,
271        _variant_index: u32,
272        _variant: &'static str,
273        _len: usize,
274    ) -> Result<Self::SerializeStructVariant> {
275        Ok(EnumVariantSerializer::new(
276            &mut self.buffer,
277            _variant,
278            ElementType::Object,
279        ))
280    }
281}
282
283impl<'a> ser::SerializeSeq for JsonbWriter<'a> {
284    type Ok = ();
285
286    type Error = Error;
287
288    fn serialize_element<T: ?Sized + Serialize>(
289        &mut self,
290        value: &T,
291    ) -> Result<()> {
292        let mut serializer = Serializer::default();
293        std::mem::swap(self.buffer, &mut serializer.buffer);
294        let r = value.serialize(&mut serializer);
295        std::mem::swap(self.buffer, &mut serializer.buffer);
296        r
297    }
298
299    fn end(self) -> Result<Self::Ok> {
300        self.finalize();
301        Ok(())
302    }
303}
304
305impl<'a> ser::SerializeTuple for JsonbWriter<'a> {
306    type Ok = ();
307
308    type Error = Error;
309
310    fn serialize_element<T: ?Sized + Serialize>(
311        &mut self,
312        value: &T,
313    ) -> Result<()> {
314        <Self as ser::SerializeSeq>::serialize_element(self, value)
315    }
316
317    fn end(self) -> Result<Self::Ok> {
318        <Self as ser::SerializeSeq>::end(self)
319    }
320}
321
322impl<'a> ser::SerializeTupleStruct for JsonbWriter<'a> {
323    type Ok = ();
324
325    type Error = Error;
326
327    fn serialize_field<T: ?Sized + Serialize>(
328        &mut self,
329        value: &T,
330    ) -> std::prelude::v1::Result<(), Self::Error> {
331        <Self as ser::SerializeTuple>::serialize_element(self, value)
332    }
333
334    fn end(self) -> Result<Self::Ok> {
335        <Self as ser::SerializeTuple>::end(self)
336    }
337}
338
339/// Serializes an enum variant as an object with a single key for the variant name
340/// and an array of the tuple fields or a map as the value.
341/// MyEnum::Variant(1, 2) -> {"Variant": [1, 2]}
342/// MyEnum::Variant { field1: 1, field2: 2 } -> {"Variant": {"field1": 1, "field2": 2}}
343/// We need to keep track of two jsonb headers, one for the inner array or map, and one for the object.
344pub struct EnumVariantSerializer<'a> {
345    map_header_start: u64,
346    inner_jsonb_writer: JsonbWriter<'a>,
347}
348
349impl<'a> EnumVariantSerializer<'a> {
350    fn new(
351        buffer: &'a mut Vec<u8>,
352        variant: &'static str,
353        inner_element_type: ElementType,
354    ) -> Self {
355        let mut map_jsonb_writer =
356            JsonbWriter::new(buffer, ElementType::Object);
357        ser::SerializeMap::serialize_key(&mut map_jsonb_writer, variant)
358            .unwrap();
359        let map_header_start = map_jsonb_writer.header_start;
360        let inner_jsonb_writer = JsonbWriter::new(buffer, inner_element_type);
361        Self {
362            map_header_start,
363            inner_jsonb_writer,
364        }
365    }
366}
367
368impl<'a> ser::SerializeTupleVariant for EnumVariantSerializer<'a> {
369    type Ok = ();
370
371    type Error = Error;
372
373    fn serialize_field<T: ?Sized + Serialize>(
374        &mut self,
375        value: &T,
376    ) -> Result<()> {
377        ser::SerializeSeq::serialize_element(
378            &mut self.inner_jsonb_writer,
379            value,
380        )
381    }
382
383    fn end(self) -> Result<Self::Ok> {
384        ser::SerializeSeq::end(JsonbWriter {
385            buffer: self.inner_jsonb_writer.buffer,
386            header_start: self.inner_jsonb_writer.header_start,
387        })?;
388        ser::SerializeMap::end(JsonbWriter {
389            buffer: self.inner_jsonb_writer.buffer,
390            header_start: self.map_header_start,
391        })
392    }
393}
394
395impl<'a> ser::SerializeMap for JsonbWriter<'a> {
396    type Ok = ();
397
398    type Error = Error;
399
400    fn serialize_key<T: ?Sized + Serialize>(&mut self, key: &T) -> Result<()> {
401        <Self as ser::SerializeSeq>::serialize_element(self, key)
402    }
403
404    fn serialize_value<T: ?Sized + Serialize>(
405        &mut self,
406        value: &T,
407    ) -> Result<()> {
408        <Self as ser::SerializeSeq>::serialize_element(self, value)
409    }
410
411    fn end(self) -> Result<Self::Ok> {
412        self.finalize();
413        Ok(())
414    }
415}
416
417impl<'a> ser::SerializeStruct for JsonbWriter<'a> {
418    type Ok = ();
419
420    type Error = Error;
421
422    fn serialize_field<T: ?Sized + Serialize>(
423        &mut self,
424        key: &'static str,
425        value: &T,
426    ) -> Result<()> {
427        <Self as ser::SerializeMap>::serialize_key(self, key)?;
428        <Self as ser::SerializeMap>::serialize_value(self, value)
429    }
430
431    fn end(self) -> Result<Self::Ok> {
432        self.finalize();
433        Ok(())
434    }
435}
436
437impl<'a> ser::SerializeStructVariant for EnumVariantSerializer<'a> {
438    type Ok = ();
439
440    type Error = Error;
441
442    fn serialize_field<T: ?Sized + Serialize>(
443        &mut self,
444        key: &'static str,
445        value: &T,
446    ) -> Result<()> {
447        ser::SerializeTupleVariant::serialize_field(self, key)?;
448        ser::SerializeTupleVariant::serialize_field(self, value)
449    }
450
451    fn end(self) -> Result<Self::Ok> {
452        ser::SerializeTupleVariant::end(self)
453    }
454}
455
/// Byte-level tests: each case compares the output of `to_vec` with a
/// hand-written JSONB encoding (header bytes followed by the payload).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_u8() {
        assert_eq!(to_vec(&42u8).unwrap(), b"\x2342");
    }

    #[test]
    fn test_serialize_i64() {
        assert_eq!(
            to_vec(&1234567890123456789i64).unwrap(),
            b"\xc3\x131234567890123456789"
        );
    }

    #[test]
    fn test_serialize_bool() {
        assert_eq!(to_vec(&true).unwrap(), b"\x01");
        assert_eq!(to_vec(&false).unwrap(), b"\x02");
    }

    #[test]
    fn test_serialize_string() {
        assert_eq!(to_vec(&"hello").unwrap(), b"\x5ahello");
    }

    /// Serializes a string of `repeats` characters and checks that the output
    /// is `expected_header` followed by the string bytes.
    fn assert_long_str(repeats: u64, expected_header: &[u8]) {
        let long_str = "x".repeat(repeats as usize);
        assert_eq!(
            to_vec(&long_str).unwrap(),
            [expected_header, long_str.as_bytes()].concat()
        );
    }

    #[test]
    fn test_serialize_various_string_lengths() {
        // Sizes 0..=11 live in the upper nibble of the first header byte.
        assert_long_str(0x0, b"\x0a");
        assert_long_str(0x1, b"\x1a");
        assert_long_str(0xb, b"\xba");
        // One explicit size byte.
        assert_long_str(0xc, b"\xca\x0c");
        assert_long_str(0xf, b"\xca\x0f");
        assert_long_str(0xff, b"\xca\xff");
        // Two explicit size bytes.
        assert_long_str(0x100, b"\xda\x01\x00");
        assert_long_str(0xffff, b"\xda\xff\xff");
        // Four explicit size bytes.
        assert_long_str(0x1_0000, b"\xea\x00\x01\x00\x00");
        assert_long_str(0x01_23_45_67, b"\xea\x01\x23\x45\x67");
        // disabled for test performance:
        // assert_long_str(0x01_0000_0000, b"\xfa\x00\x00\x00\x01\x00\x00\x00\x00");
    }

    #[test]
    fn test_serialize_array() {
        assert_eq!(
            to_vec(&Vec::<String>::new()).unwrap(),
            b"\x0b",
            "empty array"
        );
        assert_eq!(to_vec(&vec![true, false]).unwrap(), b"\x2b\x01\x02");
    }

    #[test]
    fn test_serialize_tuple() {
        assert_eq!(to_vec(&(true, 1, 2)).unwrap(), b"\x5b\x01\x131\x132");
    }

    #[test]
    fn test_serialize_tuple_struct() {
        #[derive(serde_derive::Serialize)]
        struct TupleStruct(String, f32);

        assert_eq!(
            to_vec(&TupleStruct("hello".to_string(), 3.14)).unwrap(),
            b"\xbb\x5ahello\x453.14"
        );
    }

    #[test]
    fn test_serialize_struct() {
        #[derive(serde_derive::Serialize)]
        struct TestStruct {
            smol: char,
            long_long_long_long: u64,
        }
        let test_struct = TestStruct {
            smol: 'X',
            long_long_long_long: 42,
        };
        assert_eq!(
            to_vec(&test_struct).unwrap(),
            b"\xcc\x1f\x4asmol\x1aX\xca\x13long_long_long_long\x2342"
        );
    }

    #[test]
    fn test_serialize_map() {
        let mut test_map = std::collections::HashMap::new();
        test_map.insert("k".to_string(), false);
        assert_eq!(to_vec(&test_map).unwrap(), b"\x3c\x1ak\x02",);
    }

    #[test]
    fn test_serialize_empty_map() {
        let test_map = std::collections::HashMap::<String, ()>::new();
        assert_eq!(to_vec(&test_map).unwrap(), b"\x0c",);
    }

    #[test]
    fn test_serialize_option() {
        assert_eq!(to_vec(&Some(42)).unwrap(), b"\x2342");
        assert_eq!(to_vec(&Option::<i32>::None).unwrap(), b"\x00");
    }

    #[test]
    fn test_serialize_unit() {
        assert_eq!(to_vec(&()).unwrap(), b"\x00");
    }

    #[test]
    fn test_serialize_unit_struct() {
        #[derive(serde_derive::Serialize)]
        struct UnitStruct;

        assert_eq!(to_vec(&UnitStruct).unwrap(), b"\x00");
    }

    #[test]
    fn test_serialize_enum_unit_variants() {
        #[derive(serde_derive::Serialize)]
        enum Enum {
            A,
            B,
        }

        assert_eq!(to_vec(&Enum::A).unwrap(), b"\x1aA");
        assert_eq!(to_vec(&Enum::B).unwrap(), b"\x1aB");
    }

    #[test]
    fn test_serialize_enum_newtype_variant() {
        #[derive(serde_derive::Serialize)]
        enum Enum {
            A(i32),
        }

        assert_eq!(to_vec(&Enum::A(42)).unwrap(), b"\x5c\x1aA\x2342");
    }

    #[test]
    fn test_serialize_enum_tuple_variant() {
        #[derive(serde_derive::Serialize)]
        enum Enum {
            A(i32, i32),
        }

        assert_eq!(to_vec(&Enum::A(1, 2)).unwrap(), b"\x7c\x1aA\x4b\x131\x132");
    }

    #[test]
    fn test_serialize_enum_struct_variant() {
        #[derive(serde_derive::Serialize)]
        enum E {
            S { x: bool },
        }
        let test_struct = E::S { x: true };
        assert_eq!(to_vec(&test_struct).unwrap(), b"\x6c\x1aS\x3c\x1ax\x01");
    }
}