1use messagepack_core::{Format, encode::ExtensionEncoder, io::IoWrite};
2use serde::{
3 Deserialize, Serialize, Serializer,
4 de::Visitor,
5 ser::{self, SerializeSeq},
6};
7
8use crate::ser::{CoreError, Error};
9
/// Newtype-struct name that tags a value as a MessagePack extension.
///
/// `ExtensionRef` serializes itself through `serialize_newtype_struct` with this
/// name. NOTE(review): presumably the crate's top-level serializer/deserializer
/// match on this exact string to switch into extension mode — confirm there.
pub(crate) const EXTENSION_STRUCT_NAME: &str = "$__MSGPACK_EXTENSION_STRUCT";
11
/// A borrowed MessagePack extension value: a signed type tag plus raw payload bytes.
///
/// Serializing it emits the complete extension encoding (format marker, length
/// header when the format needs one, type byte, payload). Deserializing it borrows
/// `data` directly from the input buffer.
///
/// The derived ordering compares `kind` first, then `data`.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)]
pub struct ExtensionRef<'a> {
    /// Extension type tag, encoded as a single byte.
    pub kind: i8,
    /// Extension payload bytes.
    pub data: &'a [u8],
}
18
19impl<'a> ExtensionRef<'a> {
20 pub fn new(kind: i8, data: &'a [u8]) -> Self {
21 Self { kind, data }
22 }
23}
24
25pub(crate) struct SerializeExt<'a, W> {
26 writer: &'a mut W,
27 length: &'a mut usize,
28}
29
30impl<W> AsMut<Self> for SerializeExt<'_, W> {
31 fn as_mut(&mut self) -> &mut Self {
32 self
33 }
34}
35
36impl<'a, W> SerializeExt<'a, W> {
37 pub fn new(writer: &'a mut W, length: &'a mut usize) -> Self {
38 Self { writer, length }
39 }
40}
41
42impl<W: IoWrite> SerializeExt<'_, W> {
43 fn unexpected(&self) -> Error<W::Error> {
44 ser::Error::custom("unexpected value")
45 }
46}
47
48impl<'a, 'b, W> ser::Serializer for &'a mut SerializeExt<'b, W>
49where
50 'b: 'a,
51 W: IoWrite,
52{
53 type Ok = ();
54
55 type Error = Error<W::Error>;
56
57 type SerializeSeq = SerializeExtSeq<'a, 'b, W>;
58
59 type SerializeTuple = serde::ser::Impossible<Self::Ok, Self::Error>;
60
61 type SerializeTupleStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
62
63 type SerializeTupleVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
64
65 type SerializeMap = serde::ser::Impossible<Self::Ok, Self::Error>;
66
67 type SerializeStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
68
69 type SerializeStructVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
70
71 fn serialize_bool(self, _: bool) -> Result<Self::Ok, Self::Error> {
72 Err(self.unexpected())
73 }
74
75 fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
76 self.serialize_bytes(&v.to_be_bytes())
77 }
78
79 fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
80 Err(self.unexpected())
81 }
82
83 fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
84 Err(self.unexpected())
85 }
86
87 fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
88 Err(self.unexpected())
89 }
90
91 fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
92 self.serialize_bytes(&v.to_be_bytes())
93 }
94
95 fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
96 self.serialize_bytes(&v.to_be_bytes())
97 }
98
99 fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
100 self.serialize_bytes(&v.to_be_bytes())
101 }
102
103 fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
104 self.serialize_bytes(&v.to_be_bytes())
105 }
106
107 fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
108 Err(self.unexpected())
109 }
110
111 fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
112 Err(self.unexpected())
113 }
114
115 fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
116 Err(self.unexpected())
117 }
118
119 fn serialize_str(self, _v: &str) -> Result<Self::Ok, Self::Error> {
120 Err(self.unexpected())
121 }
122
123 fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
124 self.writer.write_bytes(v).map_err(CoreError::Io)?;
125 *self.length += v.len();
126 Ok(())
127 }
128
129 fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
130 Err(self.unexpected())
131 }
132
133 fn serialize_some<T>(self, _value: &T) -> Result<Self::Ok, Self::Error>
134 where
135 T: ?Sized + Serialize,
136 {
137 Err(self.unexpected())
138 }
139
140 fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
141 Err(self.unexpected())
142 }
143
144 fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
145 Err(self.unexpected())
146 }
147
148 fn serialize_unit_variant(
149 self,
150 _name: &'static str,
151 _variant_index: u32,
152 _variant: &'static str,
153 ) -> Result<Self::Ok, Self::Error> {
154 Err(self.unexpected())
155 }
156
157 fn serialize_newtype_struct<T>(
158 self,
159 _name: &'static str,
160 value: &T,
161 ) -> Result<Self::Ok, Self::Error>
162 where
163 T: ?Sized + Serialize,
164 {
165 value.serialize(self)
166 }
167
168 fn serialize_newtype_variant<T>(
169 self,
170 _name: &'static str,
171 _variant_index: u32,
172 _variant: &'static str,
173 _value: &T,
174 ) -> Result<Self::Ok, Self::Error>
175 where
176 T: ?Sized + Serialize,
177 {
178 Err(self.unexpected())
179 }
180
181 fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
182 Ok(SerializeExtSeq::new(self))
183 }
184
185 fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
186 Err(self.unexpected())
187 }
188
189 fn serialize_tuple_struct(
190 self,
191 _name: &'static str,
192 _len: usize,
193 ) -> Result<Self::SerializeTupleStruct, Self::Error> {
194 Err(self.unexpected())
195 }
196
197 fn serialize_tuple_variant(
198 self,
199 _name: &'static str,
200 _variant_index: u32,
201 _variant: &'static str,
202 _len: usize,
203 ) -> Result<Self::SerializeTupleVariant, Self::Error> {
204 Err(self.unexpected())
205 }
206
207 fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
208 Err(self.unexpected())
209 }
210
211 fn serialize_struct(
212 self,
213 _name: &'static str,
214 _len: usize,
215 ) -> Result<Self::SerializeStruct, Self::Error> {
216 Err(self.unexpected())
217 }
218
219 fn serialize_struct_variant(
220 self,
221 _name: &'static str,
222 _variant_index: u32,
223 _variant: &'static str,
224 _len: usize,
225 ) -> Result<Self::SerializeStructVariant, Self::Error> {
226 Err(self.unexpected())
227 }
228}
229
230pub struct SerializeExtSeq<'a, 'b, W> {
231 ser: &'a mut SerializeExt<'b, W>,
232}
233
234impl<'a, 'b, W> SerializeExtSeq<'a, 'b, W> {
235 pub(crate) fn new(ser: &'a mut SerializeExt<'b, W>) -> Self {
236 Self { ser }
237 }
238}
239
240impl<'a, 'b, W> ser::SerializeSeq for SerializeExtSeq<'a, 'b, W>
241where
242 'b: 'a,
243 W: IoWrite,
244{
245 type Ok = ();
246 type Error = Error<W::Error>;
247 fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
248 where
249 T: ?Sized + Serialize,
250 {
251 value.serialize(self.ser.as_mut())
252 }
253 fn end(self) -> Result<Self::Ok, Self::Error> {
254 Ok(())
255 }
256}
257
258struct ExtInner<'a> {
259 kind: i8,
260 data: &'a [u8],
261}
262
263impl ser::Serialize for ExtInner<'_> {
264 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
265 where
266 S: Serializer,
267 {
268 let encoder = ExtensionEncoder::new(self.kind, self.data);
269 let format = encoder
270 .to_format::<()>()
271 .map_err(|_| ser::Error::custom("Invalid data length"))?;
272
273 let mut seq = serializer.serialize_seq(Some(4))?;
274
275 seq.serialize_element(serde_bytes::Bytes::new(&format.as_slice()))?;
276
277 const EMPTY: &[u8] = &[];
278
279 match format {
280 messagepack_core::Format::FixExt1 => {
281 seq.serialize_element(serde_bytes::Bytes::new(EMPTY))
282 }
283 messagepack_core::Format::FixExt2 => {
284 seq.serialize_element(serde_bytes::Bytes::new(EMPTY))
285 }
286 messagepack_core::Format::FixExt4 => {
287 seq.serialize_element(serde_bytes::Bytes::new(EMPTY))
288 }
289 messagepack_core::Format::FixExt8 => {
290 seq.serialize_element(serde_bytes::Bytes::new(EMPTY))
291 }
292 messagepack_core::Format::FixExt16 => {
293 seq.serialize_element(serde_bytes::Bytes::new(EMPTY))
294 }
295 messagepack_core::Format::Ext8 => {
296 let len = (self.data.len() as u8).to_be_bytes();
297 seq.serialize_element(serde_bytes::Bytes::new(&len))
298 }
299 messagepack_core::Format::Ext16 => {
300 let len = (self.data.len() as u16).to_be_bytes();
301 seq.serialize_element(serde_bytes::Bytes::new(&len))
302 }
303 messagepack_core::Format::Ext32 => {
304 let len = (self.data.len() as u32).to_be_bytes();
305 seq.serialize_element(serde_bytes::Bytes::new(&len))
306 }
307 _ => unreachable!(),
308 }?;
309 seq.serialize_element(serde_bytes::Bytes::new(&self.kind.to_be_bytes()))?;
310 seq.serialize_element(serde_bytes::Bytes::new(self.data))?;
311
312 seq.end()
313 }
314}
315
316impl ser::Serialize for ExtensionRef<'_> {
317 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
318 where
319 S: Serializer,
320 {
321 serializer.serialize_newtype_struct(
322 EXTENSION_STRUCT_NAME,
323 &ExtInner {
324 kind: self.kind,
325 data: self.data,
326 },
327 )
328 }
329}
330
331pub(crate) struct DeserializeExt<'de> {
332 data_len: usize,
333 pub(crate) input: &'de [u8],
334}
335
336impl AsMut<Self> for DeserializeExt<'_> {
337 fn as_mut(&mut self) -> &mut Self {
338 self
339 }
340}
341
342impl<'de> DeserializeExt<'de> {
343 pub(crate) fn new(format: Format, input: &'de [u8]) -> Result<Self, crate::de::Error> {
344 let (data_len, rest) = match format {
345 Format::FixExt1 => (1, input),
346 Format::FixExt2 => (2, input),
347 Format::FixExt4 => (4, input),
348 Format::FixExt8 => (8, input),
349 Format::FixExt16 => (16, input),
350 Format::Ext8 => {
351 let (first, rest) = input
352 .split_first_chunk::<1>()
353 .ok_or(messagepack_core::decode::Error::EofData)?;
354 let val = u8::from_be_bytes(*first);
355 (val.into(), rest)
356 }
357 Format::Ext16 => {
358 let (first, rest) = input
359 .split_first_chunk::<2>()
360 .ok_or(messagepack_core::decode::Error::EofData)?;
361 let val = u16::from_be_bytes(*first);
362 (val.into(), rest)
363 }
364 Format::Ext32 => {
365 let (first, rest) = input
366 .split_first_chunk::<4>()
367 .ok_or(messagepack_core::decode::Error::EofData)?;
368 let val = u32::from_be_bytes(*first);
369 (val as usize, rest)
370 }
371 _ => return Err(messagepack_core::decode::Error::UnexpectedFormat.into()),
372 };
373 Ok(DeserializeExt {
374 data_len,
375 input: rest,
376 })
377 }
378}
379
380impl<'de> serde::Deserializer<'de> for &mut DeserializeExt<'de> {
381 type Error = crate::de::Error;
382
383 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
384 where
385 V: Visitor<'de>,
386 {
387 Err(crate::de::Error::AnyIsUnsupported)
388 }
389
390 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
391 where
392 V: Visitor<'de>,
393 {
394 let (first, rest) = self
395 .input
396 .split_first_chunk::<1>()
397 .ok_or(messagepack_core::decode::Error::EofData)?;
398
399 let val = i8::from_be_bytes(*first);
400 self.input = rest;
401 visitor.visit_i8(val)
402 }
403
404 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
405 where
406 V: Visitor<'de>,
407 {
408 let (data, rest) = self
409 .input
410 .split_at_checked(self.data_len)
411 .ok_or(messagepack_core::decode::Error::EofData)?;
412 self.input = rest;
413 visitor.visit_borrowed_bytes(data)
414 }
415
416 fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
417 where
418 V: Visitor<'de>,
419 {
420 visitor.visit_seq(&mut self)
421 }
422
423 fn deserialize_newtype_struct<V>(
424 self,
425 _name: &'static str,
426 visitor: V,
427 ) -> Result<V::Value, Self::Error>
428 where
429 V: Visitor<'de>,
430 {
431 visitor.visit_newtype_struct(self)
432 }
433
434 serde::forward_to_deserialize_any! {
435 bool i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
436 byte_buf option unit unit_struct tuple
437 tuple_struct map struct enum identifier ignored_any
438 }
439}
440
441impl<'de> serde::de::SeqAccess<'de> for &mut DeserializeExt<'de> {
442 type Error = crate::de::Error;
443 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
444 where
445 T: serde::de::DeserializeSeed<'de>,
446 {
447 seed.deserialize(self.as_mut()).map(Some)
448 }
449}
450
451impl<'de> Deserialize<'de> for ExtensionRef<'de> {
452 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
453 where
454 D: serde::Deserializer<'de>,
455 {
456 struct ExtensionVisitor;
457
458 impl<'de> Visitor<'de> for ExtensionVisitor {
459 type Value = ExtensionRef<'de>;
460 fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
461 formatter.write_str("expect extension")
462 }
463
464 fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
465 where
466 D: serde::Deserializer<'de>,
467 {
468 deserializer.deserialize_seq(self)
469 }
470
471 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
472 where
473 A: serde::de::SeqAccess<'de>,
474 {
475 let kind = seq
476 .next_element::<i8>()?
477 .ok_or(serde::de::Error::custom("expect i8"))?;
478
479 let data = seq
480 .next_element::<&[u8]>()?
481 .ok_or(serde::de::Error::custom("expect [u8]"))?;
482
483 Ok(ExtensionRef::new(kind, data))
484 }
485 }
486 deserializer.deserialize_any(ExtensionVisitor)
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use messagepack_core::SliceWriter;
    use rstest::rstest;

    #[rstest]
    fn encode_ext() {
        let mut buf = [0_u8; 3];
        let mut writer = SliceWriter::from_slice(&mut buf);
        let mut length = 0;
        let mut ser = SerializeExt::new(&mut writer, &mut length);

        let kind: i8 = 123;

        let ext = ExtensionRef::new(kind, &[0x12]);

        ext.serialize(&mut ser).unwrap();

        assert_eq!(length, 3);
        assert_eq!(buf, [0xd4, kind.to_be_bytes()[0], 0x12]);
    }

    #[rstest]
    fn encode_ext8() {
        // A 3-byte payload has no FixExt form, so it is written as Ext8:
        // marker 0xc7, 1-byte length, type byte, payload.
        let mut buf = [0_u8; 6];
        let mut writer = SliceWriter::from_slice(&mut buf);
        let mut length = 0;
        let mut ser = SerializeExt::new(&mut writer, &mut length);

        let ext = ExtensionRef::new(5, &[0x01, 0x02, 0x03]);

        ext.serialize(&mut ser).unwrap();

        assert_eq!(length, 6);
        assert_eq!(buf, [0xc7, 0x03, 0x05, 0x01, 0x02, 0x03]);
    }

    #[rstest]
    fn decode_ext() {
        let buf = [0xd6, 0xff, 0x00, 0x00, 0x00, 0x00];
        let ext = crate::from_slice::<ExtensionRef>(&buf).unwrap();
        assert_eq!(ext.kind, -1);
        let seconds = u32::from_be_bytes(ext.data.try_into().unwrap());
        assert_eq!(seconds, 0);
    }

    #[rstest]
    fn decode_ext8_parts() {
        // Bytes following the 0xc7 marker: length 3, type 5, payload.
        let input = [0x03, 0x05, 0x01, 0x02, 0x03];
        let mut de = DeserializeExt::new(Format::Ext8, &input).unwrap();

        assert_eq!(i8::deserialize(&mut de).unwrap(), 5);
        assert_eq!(<&[u8]>::deserialize(&mut de).unwrap(), &[0x01, 0x02, 0x03]);
        assert!(de.input.is_empty());
    }
}