1use messagepack_core::{Format, extension::ExtensionRef as CoreExtensionRef, io::IoWrite};
2use serde::{
3 Serialize, Serializer,
4 de::Visitor,
5 ser::{self, SerializeSeq},
6};
7
8use crate::ser::Error;
9
/// Newtype-struct name used to tag a value as a MessagePack extension; it is
/// passed to `serialize_newtype_struct` by `ext_ref::serialize`.
pub(crate) const EXTENSION_STRUCT_NAME: &str = "$__MSGPACK_EXTENSION_STRUCT";
11
12pub(crate) struct SerializeExt<'a, W> {
13 writer: &'a mut W,
14 length: usize,
15}
16
17impl<W> AsMut<Self> for SerializeExt<'_, W> {
18 fn as_mut(&mut self) -> &mut Self {
19 self
20 }
21}
22
23impl<'a, W> SerializeExt<'a, W> {
24 pub fn new(writer: &'a mut W) -> Self {
25 Self { writer, length: 0 }
26 }
27
28 pub(crate) fn length(&self) -> usize {
29 self.length
30 }
31}
32
33impl<W: IoWrite> SerializeExt<'_, W> {
34 fn unexpected(&self) -> Error<W::Error> {
35 ser::Error::custom("unexpected value")
36 }
37}
38
39impl<'a, 'b, W> ser::Serializer for &'a mut SerializeExt<'b, W>
40where
41 'b: 'a,
42 W: IoWrite,
43{
44 type Ok = ();
45
46 type Error = Error<W::Error>;
47
48 type SerializeSeq = SerializeExtSeq<'a, 'b, W>;
49
50 type SerializeTuple = serde::ser::Impossible<Self::Ok, Self::Error>;
51
52 type SerializeTupleStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
53
54 type SerializeTupleVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
55
56 type SerializeMap = serde::ser::Impossible<Self::Ok, Self::Error>;
57
58 type SerializeStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
59
60 type SerializeStructVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
61
62 fn serialize_bool(self, _: bool) -> Result<Self::Ok, Self::Error> {
63 Err(self.unexpected())
64 }
65
66 fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
67 self.serialize_bytes(&v.to_be_bytes())
68 }
69
70 fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
71 Err(self.unexpected())
72 }
73
74 fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
75 Err(self.unexpected())
76 }
77
78 fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
79 Err(self.unexpected())
80 }
81
82 fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
83 self.serialize_bytes(&v.to_be_bytes())
84 }
85
86 fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
87 self.serialize_bytes(&v.to_be_bytes())
88 }
89
90 fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
91 self.serialize_bytes(&v.to_be_bytes())
92 }
93
94 fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
95 self.serialize_bytes(&v.to_be_bytes())
96 }
97
98 fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
99 Err(self.unexpected())
100 }
101
102 fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
103 Err(self.unexpected())
104 }
105
106 fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
107 Err(self.unexpected())
108 }
109
110 fn serialize_str(self, _v: &str) -> Result<Self::Ok, Self::Error> {
111 Err(self.unexpected())
112 }
113
114 fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
115 self.writer
116 .write(v)
117 .map_err(messagepack_core::encode::Error::Io)?;
118 self.length += v.len();
119 Ok(())
120 }
121
122 fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
123 Err(self.unexpected())
124 }
125
126 fn serialize_some<T>(self, _value: &T) -> Result<Self::Ok, Self::Error>
127 where
128 T: ?Sized + Serialize,
129 {
130 Err(self.unexpected())
131 }
132
133 fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
134 Err(self.unexpected())
135 }
136
137 fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
138 Err(self.unexpected())
139 }
140
141 fn serialize_unit_variant(
142 self,
143 _name: &'static str,
144 _variant_index: u32,
145 _variant: &'static str,
146 ) -> Result<Self::Ok, Self::Error> {
147 Err(self.unexpected())
148 }
149
150 fn serialize_newtype_struct<T>(
151 self,
152 _name: &'static str,
153 value: &T,
154 ) -> Result<Self::Ok, Self::Error>
155 where
156 T: ?Sized + Serialize,
157 {
158 value.serialize(self)
159 }
160
161 fn serialize_newtype_variant<T>(
162 self,
163 _name: &'static str,
164 _variant_index: u32,
165 _variant: &'static str,
166 _value: &T,
167 ) -> Result<Self::Ok, Self::Error>
168 where
169 T: ?Sized + Serialize,
170 {
171 Err(self.unexpected())
172 }
173
174 fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
175 Ok(SerializeExtSeq::new(self))
176 }
177
178 fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
179 Err(self.unexpected())
180 }
181
182 fn serialize_tuple_struct(
183 self,
184 _name: &'static str,
185 _len: usize,
186 ) -> Result<Self::SerializeTupleStruct, Self::Error> {
187 Err(self.unexpected())
188 }
189
190 fn serialize_tuple_variant(
191 self,
192 _name: &'static str,
193 _variant_index: u32,
194 _variant: &'static str,
195 _len: usize,
196 ) -> Result<Self::SerializeTupleVariant, Self::Error> {
197 Err(self.unexpected())
198 }
199
200 fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
201 Err(self.unexpected())
202 }
203
204 fn serialize_struct(
205 self,
206 _name: &'static str,
207 _len: usize,
208 ) -> Result<Self::SerializeStruct, Self::Error> {
209 Err(self.unexpected())
210 }
211
212 fn serialize_struct_variant(
213 self,
214 _name: &'static str,
215 _variant_index: u32,
216 _variant: &'static str,
217 _len: usize,
218 ) -> Result<Self::SerializeStructVariant, Self::Error> {
219 Err(self.unexpected())
220 }
221
222 fn collect_str<T>(self, _value: &T) -> Result<Self::Ok, Self::Error>
223 where
224 T: ?Sized + core::fmt::Display,
225 {
226 Err(self.unexpected())
227 }
228}
229
230pub struct SerializeExtSeq<'a, 'b, W> {
231 ser: &'a mut SerializeExt<'b, W>,
232}
233
234impl<'a, 'b, W> SerializeExtSeq<'a, 'b, W> {
235 pub(crate) fn new(ser: &'a mut SerializeExt<'b, W>) -> Self {
236 Self { ser }
237 }
238}
239
240impl<'a, 'b, W> ser::SerializeSeq for SerializeExtSeq<'a, 'b, W>
241where
242 'b: 'a,
243 W: IoWrite,
244{
245 type Ok = ();
246 type Error = Error<W::Error>;
247 fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
248 where
249 T: ?Sized + Serialize,
250 {
251 value.serialize(self.ser.as_mut())
252 }
253 fn end(self) -> Result<Self::Ok, Self::Error> {
254 Ok(())
255 }
256}
257
258struct Bytes<'a>(pub &'a [u8]);
259impl ser::Serialize for Bytes<'_> {
260 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
261 where
262 S: Serializer,
263 {
264 serializer.serialize_bytes(self.0)
265 }
266}
267
268struct ExtInner<'a> {
269 kind: i8,
270 data: &'a [u8],
271}
272
273impl ser::Serialize for ExtInner<'_> {
274 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
275 where
276 S: Serializer,
277 {
278 let encoder = CoreExtensionRef::new(self.kind, self.data);
279 let format = encoder
280 .to_format::<core::convert::Infallible>()
281 .map_err(|_| ser::Error::custom("Invalid data length"))?;
282
283 let mut seq = serializer.serialize_seq(Some(4))?;
284
285 seq.serialize_element(&Bytes(&format.as_slice()))?;
286
287 match format {
288 messagepack_core::Format::FixExt1
289 | messagepack_core::Format::FixExt2
290 | messagepack_core::Format::FixExt4
291 | messagepack_core::Format::FixExt8
292 | messagepack_core::Format::FixExt16 => {}
293
294 messagepack_core::Format::Ext8 => {
295 let len = (self.data.len() as u8).to_be_bytes();
296 seq.serialize_element(&Bytes(&len))?;
297 }
298 messagepack_core::Format::Ext16 => {
299 let len = (self.data.len() as u16).to_be_bytes();
300 seq.serialize_element(&Bytes(&len))?;
301 }
302 messagepack_core::Format::Ext32 => {
303 let len = (self.data.len() as u32).to_be_bytes();
304 seq.serialize_element(&Bytes(&len))?;
305 }
306 _ => return Err(ser::Error::custom("unexpected format")),
307 };
308 seq.serialize_element(&Bytes(&self.kind.to_be_bytes()))?;
309 seq.serialize_element(&Bytes(self.data))?;
310
311 seq.end()
312 }
313}
314
/// Deserializer over the body of one extension value.
///
/// Fields:
/// - `data_len` is the payload length taken from the extension header.
/// - `input` is the unread remainder of the buffer. It is advanced as the
///   type byte and the payload are consumed.
pub(crate) struct DeserializeExt<'de> {
    data_len: usize,
    pub(crate) input: &'de [u8],
}

impl AsMut<Self> for DeserializeExt<'_> {
    fn as_mut(&mut self) -> &mut Self {
        &mut *self
    }
}
325
326impl<'de> DeserializeExt<'de> {
327 pub(crate) fn new(format: Format, input: &'de [u8]) -> Result<Self, crate::de::Error> {
328 let (data_len, rest) = match format {
329 Format::FixExt1 => (1, input),
330 Format::FixExt2 => (2, input),
331 Format::FixExt4 => (4, input),
332 Format::FixExt8 => (8, input),
333 Format::FixExt16 => (16, input),
334 Format::Ext8 => {
335 let (first, rest) = input
336 .split_first_chunk::<1>()
337 .ok_or(messagepack_core::decode::Error::EofData)?;
338 let val = u8::from_be_bytes(*first);
339 (val.into(), rest)
340 }
341 Format::Ext16 => {
342 let (first, rest) = input
343 .split_first_chunk::<2>()
344 .ok_or(messagepack_core::decode::Error::EofData)?;
345 let val = u16::from_be_bytes(*first);
346 (val.into(), rest)
347 }
348 Format::Ext32 => {
349 let (first, rest) = input
350 .split_first_chunk::<4>()
351 .ok_or(messagepack_core::decode::Error::EofData)?;
352 let val = u32::from_be_bytes(*first);
353 (val as usize, rest)
354 }
355 _ => return Err(messagepack_core::decode::Error::UnexpectedFormat.into()),
356 };
357 Ok(DeserializeExt {
358 data_len,
359 input: rest,
360 })
361 }
362}
363
364impl<'de> serde::Deserializer<'de> for &mut DeserializeExt<'de> {
365 type Error = crate::de::Error;
366
367 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
368 where
369 V: Visitor<'de>,
370 {
371 Err(serde::de::Error::custom(
372 "any when deserialize extension is not supported",
373 ))
374 }
375
376 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
377 where
378 V: Visitor<'de>,
379 {
380 let (first, rest) = self
381 .input
382 .split_first_chunk::<1>()
383 .ok_or(messagepack_core::decode::Error::EofData)?;
384
385 let val = i8::from_be_bytes(*first);
386 self.input = rest;
387 visitor.visit_i8(val)
388 }
389
390 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
391 where
392 V: Visitor<'de>,
393 {
394 let (data, rest) = self
395 .input
396 .split_at_checked(self.data_len)
397 .ok_or(messagepack_core::decode::Error::EofData)?;
398 self.input = rest;
399 visitor.visit_borrowed_bytes(data)
400 }
401
402 fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
403 where
404 V: Visitor<'de>,
405 {
406 visitor.visit_seq(&mut self)
407 }
408
409 fn deserialize_newtype_struct<V>(
410 self,
411 _name: &'static str,
412 visitor: V,
413 ) -> Result<V::Value, Self::Error>
414 where
415 V: Visitor<'de>,
416 {
417 visitor.visit_newtype_struct(self)
418 }
419
420 serde::forward_to_deserialize_any! {
421 bool i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
422 byte_buf option unit unit_struct tuple
423 tuple_struct map struct enum identifier ignored_any
424 }
425}
426
427impl<'de> serde::de::SeqAccess<'de> for &mut DeserializeExt<'de> {
428 type Error = crate::de::Error;
429 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
430 where
431 T: serde::de::DeserializeSeed<'de>,
432 {
433 seed.deserialize(self.as_mut()).map(Some)
434 }
435}
436
437pub mod ext_ref {
465 use super::*;
466
467 pub fn serialize<S>(
469 ext: &messagepack_core::extension::ExtensionRef<'_>,
470 serializer: S,
471 ) -> Result<S::Ok, S::Error>
472 where
473 S: serde::Serializer,
474 {
475 serializer.serialize_newtype_struct(
476 EXTENSION_STRUCT_NAME,
477 &ExtInner {
478 kind: ext.r#type,
479 data: ext.data,
480 },
481 )
482 }
483
484 pub fn deserialize<'de, D>(
486 deserializer: D,
487 ) -> Result<messagepack_core::extension::ExtensionRef<'de>, D::Error>
488 where
489 D: serde::Deserializer<'de>,
490 {
491 struct ExtensionVisitor;
492
493 impl<'de> Visitor<'de> for ExtensionVisitor {
494 type Value = messagepack_core::extension::ExtensionRef<'de>;
495 fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
496 formatter.write_str("expect extension")
497 }
498
499 fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
500 where
501 D: serde::Deserializer<'de>,
502 {
503 deserializer.deserialize_seq(self)
504 }
505
506 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
507 where
508 A: serde::de::SeqAccess<'de>,
509 {
510 let kind = seq
511 .next_element::<i8>()?
512 .ok_or(serde::de::Error::custom("expect i8"))?;
513
514 let data = seq
515 .next_element::<&[u8]>()?
516 .ok_or(serde::de::Error::custom("expect [u8]"))?;
517
518 Ok(messagepack_core::extension::ExtensionRef::new(kind, data))
519 }
520 }
521 deserializer.deserialize_seq(ExtensionVisitor)
522 }
523}
524
525pub mod ext_fixed {
553 use serde::de;
554
555 pub fn serialize<const N: usize, S>(
557 ext: &messagepack_core::extension::FixedExtension<N>,
558 serializer: S,
559 ) -> Result<S::Ok, S::Error>
560 where
561 S: serde::Serializer,
562 {
563 super::ext_ref::serialize(&ext.as_ref(), serializer)
564 }
565
566 pub fn deserialize<'de, const N: usize, D>(
568 deserializer: D,
569 ) -> Result<messagepack_core::extension::FixedExtension<N>, D::Error>
570 where
571 D: serde::Deserializer<'de>,
572 {
573 let r = super::ext_ref::deserialize(deserializer)?;
574
575 let ext = messagepack_core::extension::FixedExtension::new(r.r#type, r.data)
576 .ok_or_else(|| de::Error::custom("extension length is too long"))?;
577 Ok(ext)
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use messagepack_core::extension::{ExtensionRef, FixedExtension};
    use rstest::rstest;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize)]
    struct WrapRef<'a>(
        #[serde(with = "ext_ref", borrow)] messagepack_core::extension::ExtensionRef<'a>,
    );

    #[rstest]
    fn encode_ext_ref() {
        let mut buf = [0_u8; 3];

        let kind: i8 = 123;

        let ext = WrapRef(ExtensionRef::new(kind, &[0x12]));
        let length = crate::to_slice(&ext, &mut buf).unwrap();

        assert_eq!(length, 3);
        assert_eq!(buf, [0xd4, kind.to_be_bytes()[0], 0x12]);
    }

    #[rstest]
    fn decode_ext_ref() {
        let buf = [0xd6, 0xff, 0x00, 0x00, 0x00, 0x00];
        let ext = crate::from_slice::<WrapRef<'_>>(&buf).unwrap().0;
        assert_eq!(ext.r#type, -1);
        let seconds = u32::from_be_bytes(ext.data.try_into().unwrap());
        assert_eq!(seconds, 0);
    }

    // A 3-byte payload has no FixExt marker, so it takes the Ext8 path:
    // marker 0xc7, one-byte length, type tag, then the payload.
    const EXT8_PAYLOAD: &[u8] = &[0xc7, 0x03, 0x05, 0x01, 0x02, 0x03];

    #[rstest]
    fn encode_ext8_ref() {
        let mut buf = [0_u8; 6];

        let ext = WrapRef(ExtensionRef::new(5, &[0x01, 0x02, 0x03]));
        let length = crate::to_slice(&ext, &mut buf).unwrap();

        assert_eq!(length, 6);
        assert_eq!(&buf[..], EXT8_PAYLOAD);
    }

    #[rstest]
    fn decode_ext8_ref() {
        let ext = crate::from_slice::<WrapRef<'_>>(EXT8_PAYLOAD).unwrap().0;
        assert_eq!(ext.r#type, 5);
        assert_eq!(ext.data, &EXT8_PAYLOAD[3..]);
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct WrapFixed<const N: usize>(
        #[serde(with = "ext_fixed")] messagepack_core::extension::FixedExtension<N>,
    );

    #[rstest]
    fn encode_ext_fixed() {
        let mut buf = [0u8; 3];
        let kind: i8 = 123;

        let ext = WrapFixed(FixedExtension::new_fixed(kind, [0x12]));
        let length = crate::to_slice(&ext, &mut buf).unwrap();

        assert_eq!(length, 3);
        assert_eq!(buf, [0xd4, kind.to_be_bytes()[0], 0x12]);
    }

    const TIMESTAMP32: &[u8] = &[0xd6, 0xff, 0x00, 0x00, 0x00, 0x00];

    #[rstest]
    fn decode_ext_fixed_bigger_will_success() {
        let ext = crate::from_slice::<WrapFixed<6>>(TIMESTAMP32).unwrap().0;
        assert_eq!(ext.r#type, -1);
        assert_eq!(ext.data(), &TIMESTAMP32[2..])
    }

    #[rstest]
    #[should_panic]
    fn decode_ext_fixed_smaller_will_failed() {
        let _ = crate::from_slice::<WrapFixed<3>>(TIMESTAMP32).unwrap();
    }
}