facet_msgpack/
serializer.rs

1//! MsgPack serializer implementing FormatSerializer.
2
3extern crate alloc;
4
5use alloc::{string::String, vec::Vec};
6use core::fmt::Write as _;
7
8use facet_format::{FormatSerializer, ScalarValue, SerializeError};
9
/// Error reported by the MsgPack serializer.
///
/// It carries only a human-readable message; `Display` prints that message
/// verbatim.
#[derive(Debug)]
pub struct MsgPackSerializeError {
    /// Description of what went wrong.
    message: String,
}

impl core::fmt::Display for MsgPackSerializeError {
    /// Writes the stored message as-is, without applying any formatter flags.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { message } = self;
        f.write_str(message)
    }
}

impl std::error::Error for MsgPackSerializeError {}
23
/// MsgPack serializer.
///
/// Encoded bytes are appended to an internal buffer;
/// [`MsgPackSerializer::finish`] hands that buffer back to the caller.
pub struct MsgPackSerializer {
    /// Bytes encoded so far.
    out: Vec<u8>,
    /// Containers that have been opened but not yet closed, innermost last.
    stack: Vec<ContainerState>,
}

/// An open container whose header still holds a placeholder count.
#[derive(Debug)]
enum ContainerState {
    /// Open map: `count` entries so far, header byte at `out[count_pos]`.
    Struct { count: usize, count_pos: usize },
    /// Open array: `count` elements so far, header byte at `out[count_pos]`.
    Seq { count: usize, count_pos: usize },
}

impl MsgPackSerializer {
    /// Create a new MsgPack serializer with an empty output buffer.
    pub fn new() -> Self {
        let out = Vec::new();
        let stack = Vec::new();
        Self { out, stack }
    }

    /// Consume the serializer and return the output bytes.
    ///
    /// Containers that are still open (which should not happen with
    /// well-formed input) get their headers patched with the counts recorded
    /// so far, innermost first.
    pub fn finish(mut self) -> Vec<u8> {
        loop {
            match self.stack.pop() {
                Some(ContainerState::Struct { count, count_pos }) => {
                    self.patch_map_count(count_pos, count)
                }
                Some(ContainerState::Seq { count, count_pos }) => {
                    self.patch_array_count(count_pos, count)
                }
                None => break,
            }
        }
        self.out
    }

    /// Append a marker byte followed by its big-endian payload.
    fn push_tagged(&mut self, marker: u8, payload: &[u8]) {
        self.out.push(marker);
        self.out.extend_from_slice(payload);
    }

    /// Write `nil`.
    fn write_nil(&mut self) {
        self.out.push(0xc0);
    }

    /// Write `true` (0xc3) or `false` (0xc2).
    fn write_bool(&mut self, v: bool) {
        let marker = if v { 0xc3 } else { 0xc2 };
        self.out.push(marker);
    }

    /// Write an unsigned integer using the smallest encoding that fits.
    fn write_u64(&mut self, n: u64) {
        if n <= 0x7f {
            // positive fixint: the value is its own marker
            self.out.push(n as u8);
        } else if n <= u64::from(u8::MAX) {
            // uint8
            self.push_tagged(0xcc, &[n as u8]);
        } else if n <= u64::from(u16::MAX) {
            // uint16
            self.push_tagged(0xcd, &(n as u16).to_be_bytes());
        } else if n <= u64::from(u32::MAX) {
            // uint32
            self.push_tagged(0xce, &(n as u32).to_be_bytes());
        } else {
            // uint64
            self.push_tagged(0xcf, &n.to_be_bytes());
        }
    }

    /// Write a signed integer using the smallest encoding that fits.
    ///
    /// Non-negative values are delegated to [`Self::write_u64`].
    fn write_i64(&mut self, n: i64) {
        if n >= 0 {
            self.write_u64(n as u64);
        } else if n >= -32 {
            // negative fixint: the low byte (0xe0..=0xff) is the whole encoding
            self.out.push(n as u8);
        } else if n >= i64::from(i8::MIN) {
            // int8
            self.push_tagged(0xd0, &[n as u8]);
        } else if n >= i64::from(i16::MIN) {
            // int16
            self.push_tagged(0xd1, &(n as i16).to_be_bytes());
        } else if n >= i64::from(i32::MIN) {
            // int32
            self.push_tagged(0xd2, &(n as i32).to_be_bytes());
        } else {
            // int64
            self.push_tagged(0xd3, &n.to_be_bytes());
        }
    }

    /// Write a float as float64 (marker 0xcb + 8 big-endian bytes).
    fn write_f64(&mut self, n: f64) {
        self.push_tagged(0xcb, &n.to_be_bytes());
    }

    /// Write a UTF-8 string with the shortest str header for its byte length.
    fn write_str(&mut self, s: &str) {
        let payload = s.as_bytes();
        let len = payload.len();

        if len <= 31 {
            // fixstr: length lives in the low five bits of the marker
            self.out.push(0xa0 | len as u8);
        } else if len <= 0xff {
            // str8
            self.push_tagged(0xd9, &[len as u8]);
        } else if len <= 0xffff {
            // str16
            self.push_tagged(0xda, &(len as u16).to_be_bytes());
        } else {
            // str32; NOTE(review): a length above u32::MAX is truncated by the cast
            self.push_tagged(0xdb, &(len as u32).to_be_bytes());
        }
        self.out.extend_from_slice(payload);
    }

    /// Write a byte string with the shortest bin header for its length.
    fn write_bin(&mut self, bytes: &[u8]) {
        let len = bytes.len();

        if len <= 0xff {
            // bin8
            self.push_tagged(0xc4, &[len as u8]);
        } else if len <= 0xffff {
            // bin16
            self.push_tagged(0xc5, &(len as u16).to_be_bytes());
        } else {
            // bin32; NOTE(review): a length above u32::MAX is truncated by the cast
            self.push_tagged(0xc6, &(len as u32).to_be_bytes());
        }
        self.out.extend_from_slice(bytes);
    }

    /// Push a one-byte placeholder header and return its offset in `out`.
    fn open_container(&mut self, placeholder: u8) -> usize {
        let count_pos = self.out.len();
        self.out.push(placeholder);
        count_pos
    }

    /// Replace the placeholder at `count_pos` with the final header for `count`.
    ///
    /// Counts up to 15 fit in the placeholder byte itself (`fix_base | count`).
    /// Larger counts swap the single byte for a 16- or 32-bit header, shifting
    /// everything written after it.
    fn patch_header(
        &mut self,
        count_pos: usize,
        count: usize,
        fix_base: u8,
        marker16: u8,
        marker32: u8,
    ) {
        if count <= 15 {
            self.out[count_pos] = fix_base | count as u8;
        } else if count <= 0xffff {
            let [hi, lo] = (count as u16).to_be_bytes();
            let _ = self.out.splice(count_pos..=count_pos, [marker16, hi, lo]);
        } else {
            let [b0, b1, b2, b3] = (count as u32).to_be_bytes();
            let _ = self
                .out
                .splice(count_pos..=count_pos, [marker32, b0, b1, b2, b3]);
        }
    }

    /// Write a map header with placeholder count, return position of count.
    ///
    /// The placeholder is an empty fixmap; `patch_map_count` rewrites it, or
    /// widens it to map16/map32, once the real count is known.
    fn begin_map(&mut self) -> usize {
        self.open_container(0x80)
    }

    /// Finalize the map header at `count_pos` (fixmap / map16 / map32).
    fn patch_map_count(&mut self, count_pos: usize, count: usize) {
        self.patch_header(count_pos, count, 0x80, 0xde, 0xdf);
    }

    /// Write an array header with placeholder count, return position of count.
    fn begin_array(&mut self) -> usize {
        self.open_container(0x90)
    }

    /// Finalize the array header at `count_pos` (fixarray / array16 / array32).
    fn patch_array_count(&mut self, count_pos: usize, count: usize) {
        self.patch_header(count_pos, count, 0x90, 0xdc, 0xdd);
    }
}

impl Default for MsgPackSerializer {
    /// Same as [`MsgPackSerializer::new`].
    fn default() -> Self {
        Self::new()
    }
}
259
260impl FormatSerializer for MsgPackSerializer {
261    type Error = MsgPackSerializeError;
262
263    fn begin_struct(&mut self) -> Result<(), Self::Error> {
264        let count_pos = self.begin_map();
265        self.stack.push(ContainerState::Struct {
266            count: 0,
267            count_pos,
268        });
269        Ok(())
270    }
271
272    fn field_key(&mut self, key: &str) -> Result<(), Self::Error> {
273        // Increment count in current struct
274        if let Some(ContainerState::Struct { count, .. }) = self.stack.last_mut() {
275            *count += 1;
276        }
277        self.write_str(key);
278        Ok(())
279    }
280
281    fn end_struct(&mut self) -> Result<(), Self::Error> {
282        match self.stack.pop() {
283            Some(ContainerState::Struct { count, count_pos }) => {
284                self.patch_map_count(count_pos, count);
285                Ok(())
286            }
287            _ => Err(MsgPackSerializeError {
288                message: "end_struct called without matching begin_struct".into(),
289            }),
290        }
291    }
292
293    fn begin_seq(&mut self) -> Result<(), Self::Error> {
294        let count_pos = self.begin_array();
295        self.stack.push(ContainerState::Seq {
296            count: 0,
297            count_pos,
298        });
299        Ok(())
300    }
301
302    fn end_seq(&mut self) -> Result<(), Self::Error> {
303        match self.stack.pop() {
304            Some(ContainerState::Seq { count, count_pos }) => {
305                self.patch_array_count(count_pos, count);
306                Ok(())
307            }
308            _ => Err(MsgPackSerializeError {
309                message: "end_seq called without matching begin_seq".into(),
310            }),
311        }
312    }
313
314    fn scalar(&mut self, scalar: ScalarValue<'_>) -> Result<(), Self::Error> {
315        // Increment count in current sequence
316        if let Some(ContainerState::Seq { count, .. }) = self.stack.last_mut() {
317            *count += 1;
318        }
319
320        match scalar {
321            ScalarValue::Null => self.write_nil(),
322            ScalarValue::Bool(v) => self.write_bool(v),
323            ScalarValue::U64(n) => self.write_u64(n),
324            ScalarValue::I64(n) => self.write_i64(n),
325            ScalarValue::U128(n) => {
326                // MsgPack doesn't natively support u128, serialize as string
327                let mut buf = String::new();
328                write!(buf, "{}", n).unwrap();
329                self.write_str(&buf);
330            }
331            ScalarValue::I128(n) => {
332                // MsgPack doesn't natively support i128, serialize as string
333                let mut buf = String::new();
334                write!(buf, "{}", n).unwrap();
335                self.write_str(&buf);
336            }
337            ScalarValue::F64(n) => self.write_f64(n),
338            ScalarValue::Str(s) => self.write_str(&s),
339            ScalarValue::Bytes(bytes) => self.write_bin(&bytes),
340        }
341        Ok(())
342    }
343}
344
345/// Serialize a value to MsgPack bytes.
346pub fn to_vec<'facet, T>(value: &T) -> Result<Vec<u8>, SerializeError<MsgPackSerializeError>>
347where
348    T: facet_core::Facet<'facet>,
349{
350    let mut ser = MsgPackSerializer::new();
351    facet_format::serialize_root(&mut ser, facet_reflect::Peek::new(value))?;
352    Ok(ser.finish())
353}
354
355/// Serialize a value to MsgPack bytes using a writer.
356pub fn to_writer<'facet, T, W>(writer: &mut W, value: &T) -> Result<(), std::io::Error>
357where
358    T: facet_core::Facet<'facet>,
359    W: std::io::Write,
360{
361    let bytes = to_vec(value).map_err(|e| std::io::Error::other(e.to_string()))?;
362    writer.write_all(&bytes)
363}