Skip to main content

facet_msgpack/
serializer.rs

//! MsgPack serializer implementing FormatSerializer.
2
3extern crate alloc;
4
5use alloc::{string::String, vec::Vec};
6use core::fmt::Write as _;
7
8use facet_format::{FormatSerializer, ScalarValue, SerializeError};
9
/// MsgPack serializer error.
///
/// Carries a human-readable message; its `Display` output is exactly that
/// message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MsgPackSerializeError {
    /// Description of what went wrong.
    message: String,
}
15
16impl core::fmt::Display for MsgPackSerializeError {
17    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
18        f.write_str(&self.message)
19    }
20}
21
22impl std::error::Error for MsgPackSerializeError {}
23
24/// MsgPack serializer.
25pub struct MsgPackSerializer {
26    out: Vec<u8>,
27    /// Stack tracking whether we're in a struct or sequence, and item counts
28    stack: Vec<ContainerState>,
29}
30
/// Bookkeeping for one container that has been opened but not yet closed.
///
/// In both variants `count` is the number of entries recorded so far and
/// `count_pos` is the index in the output buffer of the container's
/// one-byte placeholder header, which is patched when the container ends.
#[derive(Debug)]
enum ContainerState {
    /// A struct, encoded as a map; `count` is bumped once per field key.
    Struct { count: usize, count_pos: usize },
    /// A sequence, encoded as an array; `count` is bumped once per element.
    Seq { count: usize, count_pos: usize },
}
36
37impl MsgPackSerializer {
38    /// Create a new MsgPack serializer.
39    pub const fn new() -> Self {
40        Self {
41            out: Vec::new(),
42            stack: Vec::new(),
43        }
44    }
45
46    /// Consume the serializer and return the output bytes.
47    pub fn finish(mut self) -> Vec<u8> {
48        // Patch up any remaining container counts (shouldn't happen with well-formed input)
49        while let Some(state) = self.stack.pop() {
50            match state {
51                ContainerState::Struct { count, count_pos } => {
52                    self.patch_map_count(count_pos, count);
53                }
54                ContainerState::Seq { count, count_pos } => {
55                    self.patch_array_count(count_pos, count);
56                }
57            }
58        }
59        self.out
60    }
61
62    fn write_nil(&mut self) {
63        self.out.push(0xc0);
64    }
65
66    fn write_bool(&mut self, v: bool) {
67        self.out.push(if v { 0xc3 } else { 0xc2 });
68    }
69
70    fn write_u64(&mut self, n: u64) {
71        match n {
72            0..=127 => {
73                // positive fixint
74                self.out.push(n as u8);
75            }
76            128..=255 => {
77                // uint8
78                self.out.push(0xcc);
79                self.out.push(n as u8);
80            }
81            256..=65535 => {
82                // uint16
83                self.out.push(0xcd);
84                self.out.extend_from_slice(&(n as u16).to_be_bytes());
85            }
86            65536..=4294967295 => {
87                // uint32
88                self.out.push(0xce);
89                self.out.extend_from_slice(&(n as u32).to_be_bytes());
90            }
91            _ => {
92                // uint64
93                self.out.push(0xcf);
94                self.out.extend_from_slice(&n.to_be_bytes());
95            }
96        }
97    }
98
99    fn write_i64(&mut self, n: i64) {
100        match n {
101            // Positive range - use unsigned encoding
102            0..=i64::MAX => self.write_u64(n as u64),
103            // Negative fixint (-32 to -1)
104            -32..=-1 => {
105                self.out.push(n as u8);
106            }
107            // int8 (-128 to -33)
108            -128..=-33 => {
109                self.out.push(0xd0);
110                self.out.push(n as u8);
111            }
112            // int16
113            -32768..=-129 => {
114                self.out.push(0xd1);
115                self.out.extend_from_slice(&(n as i16).to_be_bytes());
116            }
117            // int32
118            -2147483648..=-32769 => {
119                self.out.push(0xd2);
120                self.out.extend_from_slice(&(n as i32).to_be_bytes());
121            }
122            // int64
123            _ => {
124                self.out.push(0xd3);
125                self.out.extend_from_slice(&n.to_be_bytes());
126            }
127        }
128    }
129
130    fn write_f64(&mut self, n: f64) {
131        self.out.push(0xcb);
132        self.out.extend_from_slice(&n.to_be_bytes());
133    }
134
135    fn write_str(&mut self, s: &str) {
136        let bytes = s.as_bytes();
137        let len = bytes.len();
138
139        match len {
140            0..=31 => {
141                // fixstr
142                self.out.push(0xa0 | len as u8);
143            }
144            32..=255 => {
145                // str8
146                self.out.push(0xd9);
147                self.out.push(len as u8);
148            }
149            256..=65535 => {
150                // str16
151                self.out.push(0xda);
152                self.out.extend_from_slice(&(len as u16).to_be_bytes());
153            }
154            _ => {
155                // str32
156                self.out.push(0xdb);
157                self.out.extend_from_slice(&(len as u32).to_be_bytes());
158            }
159        }
160        self.out.extend_from_slice(bytes);
161    }
162
163    fn write_bin(&mut self, bytes: &[u8]) {
164        let len = bytes.len();
165
166        match len {
167            0..=255 => {
168                // bin8
169                self.out.push(0xc4);
170                self.out.push(len as u8);
171            }
172            256..=65535 => {
173                // bin16
174                self.out.push(0xc5);
175                self.out.extend_from_slice(&(len as u16).to_be_bytes());
176            }
177            _ => {
178                // bin32
179                self.out.push(0xc6);
180                self.out.extend_from_slice(&(len as u32).to_be_bytes());
181            }
182        }
183        self.out.extend_from_slice(bytes);
184    }
185
186    /// Write a map header with placeholder count, return position of count.
187    fn begin_map(&mut self) -> usize {
188        // Use map32 format for flexibility (we'll patch it later)
189        // Actually, we'll use fixmap initially and upgrade if needed
190        let count_pos = self.out.len();
191        self.out.push(0x80); // fixmap with 0 elements (placeholder)
192        count_pos
193    }
194
195    fn patch_map_count(&mut self, count_pos: usize, count: usize) {
196        match count {
197            0..=15 => {
198                // fixmap - just update the byte
199                self.out[count_pos] = 0x80 | count as u8;
200            }
201            16..=65535 => {
202                // Need to convert to map16
203                // First, save everything after the placeholder
204                let tail = self.out[count_pos + 1..].to_vec();
205                self.out.truncate(count_pos);
206                self.out.push(0xde); // map16
207                self.out.extend_from_slice(&(count as u16).to_be_bytes());
208                self.out.extend_from_slice(&tail);
209            }
210            _ => {
211                // Need to convert to map32
212                let tail = self.out[count_pos + 1..].to_vec();
213                self.out.truncate(count_pos);
214                self.out.push(0xdf); // map32
215                self.out.extend_from_slice(&(count as u32).to_be_bytes());
216                self.out.extend_from_slice(&tail);
217            }
218        }
219    }
220
221    /// Write an array header with placeholder count, return position of count.
222    fn begin_array(&mut self) -> usize {
223        let count_pos = self.out.len();
224        self.out.push(0x90); // fixarray with 0 elements (placeholder)
225        count_pos
226    }
227
228    fn patch_array_count(&mut self, count_pos: usize, count: usize) {
229        match count {
230            0..=15 => {
231                // fixarray - just update the byte
232                self.out[count_pos] = 0x90 | count as u8;
233            }
234            16..=65535 => {
235                // Need to convert to array16
236                let tail = self.out[count_pos + 1..].to_vec();
237                self.out.truncate(count_pos);
238                self.out.push(0xdc); // array16
239                self.out.extend_from_slice(&(count as u16).to_be_bytes());
240                self.out.extend_from_slice(&tail);
241            }
242            _ => {
243                // Need to convert to array32
244                let tail = self.out[count_pos + 1..].to_vec();
245                self.out.truncate(count_pos);
246                self.out.push(0xdd); // array32
247                self.out.extend_from_slice(&(count as u32).to_be_bytes());
248                self.out.extend_from_slice(&tail);
249            }
250        }
251    }
252
253    /// Record a value emission in the current sequence, if any.
254    fn bump_seq_count_for_value(&mut self) {
255        if let Some(ContainerState::Seq { count, .. }) = self.stack.last_mut() {
256            *count += 1;
257        }
258    }
259}
260
261impl Default for MsgPackSerializer {
262    fn default() -> Self {
263        Self::new()
264    }
265}
266
267impl FormatSerializer for MsgPackSerializer {
268    type Error = MsgPackSerializeError;
269
270    fn begin_struct(&mut self) -> Result<(), Self::Error> {
271        self.bump_seq_count_for_value();
272        let count_pos = self.begin_map();
273        self.stack.push(ContainerState::Struct {
274            count: 0,
275            count_pos,
276        });
277        Ok(())
278    }
279
280    fn field_key(&mut self, key: &str) -> Result<(), Self::Error> {
281        // Increment count in current struct
282        if let Some(ContainerState::Struct { count, .. }) = self.stack.last_mut() {
283            *count += 1;
284        }
285        self.write_str(key);
286        Ok(())
287    }
288
289    fn end_struct(&mut self) -> Result<(), Self::Error> {
290        match self.stack.pop() {
291            Some(ContainerState::Struct { count, count_pos }) => {
292                self.patch_map_count(count_pos, count);
293                Ok(())
294            }
295            _ => Err(MsgPackSerializeError {
296                message: "end_struct called without matching begin_struct".into(),
297            }),
298        }
299    }
300
301    fn begin_seq(&mut self) -> Result<(), Self::Error> {
302        self.bump_seq_count_for_value();
303        let count_pos = self.begin_array();
304        self.stack.push(ContainerState::Seq {
305            count: 0,
306            count_pos,
307        });
308        Ok(())
309    }
310
311    fn end_seq(&mut self) -> Result<(), Self::Error> {
312        match self.stack.pop() {
313            Some(ContainerState::Seq { count, count_pos }) => {
314                self.patch_array_count(count_pos, count);
315                Ok(())
316            }
317            _ => Err(MsgPackSerializeError {
318                message: "end_seq called without matching begin_seq".into(),
319            }),
320        }
321    }
322
323    fn scalar(&mut self, scalar: ScalarValue<'_>) -> Result<(), Self::Error> {
324        self.bump_seq_count_for_value();
325
326        match scalar {
327            ScalarValue::Null | ScalarValue::Unit => self.write_nil(),
328            ScalarValue::Bool(v) => self.write_bool(v),
329            ScalarValue::Char(c) => {
330                let mut buf = [0u8; 4];
331                self.write_str(c.encode_utf8(&mut buf));
332            }
333            ScalarValue::U64(n) => self.write_u64(n),
334            ScalarValue::I64(n) => self.write_i64(n),
335            ScalarValue::U128(n) => {
336                // MsgPack doesn't natively support u128, serialize as string
337                let mut buf = String::new();
338                write!(buf, "{}", n).unwrap();
339                self.write_str(&buf);
340            }
341            ScalarValue::I128(n) => {
342                // MsgPack doesn't natively support i128, serialize as string
343                let mut buf = String::new();
344                write!(buf, "{}", n).unwrap();
345                self.write_str(&buf);
346            }
347            ScalarValue::F64(n) => self.write_f64(n),
348            ScalarValue::Str(s) => self.write_str(&s),
349            ScalarValue::Bytes(bytes) => self.write_bin(&bytes),
350        }
351        Ok(())
352    }
353
354    fn is_self_describing(&self) -> bool {
355        false
356    }
357}
358
359/// Serialize a value to MsgPack bytes.
360pub fn to_vec<'facet, T>(value: &T) -> Result<Vec<u8>, SerializeError<MsgPackSerializeError>>
361where
362    T: facet_core::Facet<'facet>,
363{
364    let mut ser = MsgPackSerializer::new();
365    facet_format::serialize_root(&mut ser, facet_reflect::Peek::new(value))?;
366    Ok(ser.finish())
367}
368
369/// Serialize a value to MsgPack bytes using a writer.
370pub fn to_writer<'facet, T, W>(writer: &mut W, value: &T) -> Result<(), std::io::Error>
371where
372    T: facet_core::Facet<'facet>,
373    W: std::io::Write,
374{
375    let bytes = to_vec(value).map_err(|e| std::io::Error::other(e.to_string()))?;
376    writer.write_all(&bytes)
377}