serde_beve/ser/
map.rs

1use super::{SeqSerializer, Serializer};
2use crate::{Value, error::Error, headers::ObjectKind};
3use serde::{
4    Serialize,
5    ser::{SerializeMap, SerializeStruct, SerializeStructVariant},
6};
7use std::io::Write;
8
9pub struct MapSerializer<'a, W: Write> {
10    serializer: &'a mut Serializer<W>,
11    kind: Option<ObjectKind>,
12    keys: Vec<Value>,
13    values: Vec<Value>,
14    key: bool,
15}
16
17impl<'a, W: Write> MapSerializer<'a, W> {
18    pub fn new(serializer: &'a mut Serializer<W>, kind: Option<ObjectKind>) -> Self {
19        serializer.write = false;
20        Self {
21            serializer,
22            kind,
23            keys: Vec::new(),
24            values: Vec::new(),
25            key: false,
26        }
27    }
28
29    fn ensure_kind(&mut self, expected: ObjectKind) -> Result<(), Error> {
30        if self.key {
31            match self.kind {
32                None => self.kind = Some(expected),
33                Some(found) => {
34                    if found != expected {
35                        return Err(Error::MismatchedKeyType { expected, found });
36                    }
37                }
38            }
39        }
40        Ok(())
41    }
42}
43
44impl<'a, W: Write> SerializeMap for MapSerializer<'a, W> {
45    type Ok = Value;
46    type Error = Error;
47
48    fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
49    where
50        T: ?Sized + serde::Serialize,
51    {
52        self.key = true;
53        let key = key.serialize(&mut *self)?;
54        self.keys.push(key);
55        Ok(())
56    }
57
58    fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
59    where
60        T: ?Sized + serde::Serialize,
61    {
62        self.key = false;
63        let value = value.serialize(&mut *self)?;
64        self.values.push(value);
65        Ok(())
66    }
67
68    fn end(self) -> Result<Self::Ok, Self::Error> {
69        macro_rules! convert_object {
70            ( $( $kind:ident => $object:ident ),* $(,)? ) => {
71                match self.kind {
72                    None => Value::StringObject(vec![]),
73                    $(
74                        Some(ObjectKind::$kind) => {
75                            let keys = self.keys.into_iter().map(|v| match v {
76                                Value::$kind(v) => v,
77                                _ => unreachable!(),
78                            });
79                            let fields = keys.zip(self.values).collect();
80                            Value::$object(fields)
81                        }
82                    )*
83                }
84            }
85        }
86
87        let value = convert_object! {
88            String => StringObject,
89            I8 => I8Object,
90            I16 => I16Object,
91            I32 => I32Object,
92            I64 => I64Object,
93            I128 => I128Object,
94            U8 => U8Object,
95            U16 => U16Object,
96            U32 => U32Object,
97            U64 => U64Object,
98            U128 => U128Object,
99        };
100
101        self.serializer.write = true;
102        self.serializer.serialize_value(&value)?;
103        Ok(value)
104    }
105}
106
107impl<'a, W: Write> SerializeStruct for MapSerializer<'a, W> {
108    type Ok = Value;
109    type Error = Error;
110
111    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
112    where
113        T: ?Sized + serde::Serialize,
114    {
115        self.serialize_key(key)?;
116        self.serialize_value(value)?;
117        Ok(())
118    }
119
120    fn end(self) -> Result<Self::Ok, Self::Error> {
121        SerializeMap::end(self)
122    }
123}
124
125impl<'a, W: Write> SerializeStructVariant for MapSerializer<'a, W> {
126    type Ok = Value;
127    type Error = Error;
128
129    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
130    where
131        T: ?Sized + Serialize,
132    {
133        self.serialize_key(key)?;
134        self.serialize_value(value)?;
135        Ok(())
136    }
137
138    fn end(self) -> Result<Self::Ok, Self::Error> {
139        SerializeMap::end(self)
140    }
141}
142
/// Generates one primitive `serialize_*` method for a serializer that wraps
/// another serializer in a `serializer` field and tracks key position in `key`.
///
/// * `serialize_map_type!(name, Type, Kind)` — the primitive is a legal map
///   key: `ensure_kind(ObjectKind::Kind)` is consulted first, then the call is
///   forwarded to the wrapped serializer.
/// * `serialize_map_type!(name, Type)` — the primitive is not a legal map key:
///   in key position the method fails with `Error::InvalidKey`, otherwise the
///   call is forwarded to the wrapped serializer.
macro_rules! serialize_map_type {
    ($fn:ident, $ty:ty, $kind:ident) => {
        fn $fn(self, v: $ty) -> Result<Self::Ok, Self::Error> {
            self.ensure_kind(ObjectKind::$kind)?;
            // Forward to the wrapped serializer. Calling `self.$fn(v)` here
            // would resolve to this very method and recurse without end.
            self.serializer.$fn(v)
        }
    };
    ($fn:ident, $ty:ty) => {
        fn $fn(self, v: $ty) -> Result<Self::Ok, Self::Error> {
            if self.key {
                Err(Error::InvalidKey)
            } else {
                self.serializer.$fn(v)
            }
        }
    };
}
160
161impl<'a, 'b, W: Write> serde::Serializer for &'b mut MapSerializer<'a, W> {
162    type Ok = Value;
163    type Error = Error;
164
165    type SerializeSeq = SeqSerializer<'b, W>;
166    type SerializeTuple = SeqSerializer<'b, W>;
167    type SerializeTupleStruct = SeqSerializer<'b, W>;
168    type SerializeTupleVariant = SeqSerializer<'b, W>;
169    type SerializeMap = MapSerializer<'b, W>;
170    type SerializeStruct = MapSerializer<'b, W>;
171    type SerializeStructVariant = MapSerializer<'b, W>;
172
173    serialize_map_type!(serialize_bool, bool);
174    serialize_map_type!(serialize_i8, i8, I8);
175    serialize_map_type!(serialize_i16, i16, I16);
176    serialize_map_type!(serialize_i32, i32, I32);
177    serialize_map_type!(serialize_i64, i64, I64);
178    serialize_map_type!(serialize_i128, i128, I128);
179    serialize_map_type!(serialize_u8, u8, U8);
180    serialize_map_type!(serialize_u16, u16, U16);
181    serialize_map_type!(serialize_u32, u32, U32);
182    serialize_map_type!(serialize_u64, u64, U64);
183    serialize_map_type!(serialize_u128, u128, U128);
184    serialize_map_type!(serialize_f32, f32);
185    serialize_map_type!(serialize_f64, f64);
186    serialize_map_type!(serialize_str, &str, String);
187    serialize_map_type!(serialize_bytes, &[u8]);
188
189    fn serialize_char(self, v: char) -> Result<Self::Ok, Self::Error> {
190        self.serialize_str(&v.to_string())
191    }
192
193    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
194        self.serialize_unit()
195    }
196
197    fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
198    where
199        T: ?Sized + serde::Serialize,
200    {
201        value.serialize(self)
202    }
203
204    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
205        if self.key {
206            Err(Error::InvalidKey)
207        } else {
208            self.serializer.serialize_unit()
209        }
210    }
211
212    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
213        self.serialize_unit()
214    }
215
216    fn serialize_unit_variant(
217        self,
218        name: &'static str,
219        variant_index: u32,
220        variant: &'static str,
221    ) -> Result<Self::Ok, Self::Error> {
222        if self.key {
223            Err(Error::InvalidKey)
224        } else {
225            self.serializer
226                .serialize_unit_variant(name, variant_index, variant)
227        }
228    }
229
230    fn serialize_newtype_struct<T>(
231        self,
232        name: &'static str,
233        value: &T,
234    ) -> Result<Self::Ok, Self::Error>
235    where
236        T: ?Sized + serde::Serialize,
237    {
238        self.serializer.serialize_newtype_struct(name, value)
239    }
240
241    fn serialize_newtype_variant<T>(
242        self,
243        name: &'static str,
244        variant_index: u32,
245        variant: &'static str,
246        value: &T,
247    ) -> Result<Self::Ok, Self::Error>
248    where
249        T: ?Sized + serde::Serialize,
250    {
251        if self.key {
252            Err(Error::InvalidKey)
253        } else {
254            self.serializer
255                .serialize_newtype_variant(name, variant_index, variant, value)
256        }
257    }
258
259    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
260        if self.key {
261            Err(Error::InvalidKey)
262        } else {
263            self.serializer.serialize_seq(len)
264        }
265    }
266
267    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> {
268        if self.key {
269            Err(Error::InvalidKey)
270        } else {
271            self.serializer.serialize_tuple(len)
272        }
273    }
274
275    fn serialize_tuple_struct(
276        self,
277        name: &'static str,
278        len: usize,
279    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
280        if self.key {
281            Err(Error::InvalidKey)
282        } else {
283            self.serializer.serialize_tuple_struct(name, len)
284        }
285    }
286
287    fn serialize_tuple_variant(
288        self,
289        name: &'static str,
290        variant_index: u32,
291        variant: &'static str,
292        len: usize,
293    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
294        if self.key {
295            Err(Error::InvalidKey)
296        } else {
297            self.serializer
298                .serialize_tuple_variant(name, variant_index, variant, len)
299        }
300    }
301
302    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
303        if self.key {
304            Err(Error::InvalidKey)
305        } else {
306            self.serializer.serialize_map(len)
307        }
308    }
309
310    fn serialize_struct(
311        self,
312        name: &'static str,
313        len: usize,
314    ) -> Result<Self::SerializeStruct, Self::Error> {
315        if self.key {
316            Err(Error::InvalidKey)
317        } else {
318            self.serializer.serialize_struct(name, len)
319        }
320    }
321
322    fn serialize_struct_variant(
323        self,
324        name: &'static str,
325        variant_index: u32,
326        variant: &'static str,
327        len: usize,
328    ) -> Result<Self::SerializeStructVariant, Self::Error> {
329        if self.key {
330            Err(Error::InvalidKey)
331        } else {
332            self.serializer
333                .serialize_struct_variant(name, variant_index, variant, len)
334        }
335    }
336}