messagepack_serde/extension/
mod.rs1pub(crate) mod de;
4pub(crate) mod ser;
5
6use serde::{Serialize, Serializer, de::Visitor};
/// Sentinel newtype-struct name that `ext_ref::serialize` passes to
/// `serialize_newtype_struct` to wrap an extension value.
///
/// NOTE(review): presumably the crate's own `ser`/`de` modules match on this
/// exact name to switch into MessagePack-extension handling — confirm there.
pub(crate) const EXTENSION_STRUCT_NAME: &str = "$__MSGPACK_EXTENSION_STRUCT";
8
9#[cfg(feature = "alloc")]
10mod owned;
11#[cfg(feature = "alloc")]
12pub use owned::ext_owned;
13
14mod timestamp;
15pub use timestamp::{timestamp32, timestamp64, timestamp96};
16
17struct Bytes<'a>(pub &'a [u8]);
18impl Serialize for Bytes<'_> {
19 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
20 where
21 S: Serializer,
22 {
23 serializer.serialize_bytes(self.0)
24 }
25}
26
27struct ExtInner<'a> {
28 kind: i8,
29 data: &'a [u8],
30}
31
32impl Serialize for ExtInner<'_> {
33 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34 where
35 S: Serializer,
36 {
37 use messagepack_core::extension::ExtensionRef;
38 use serde::ser::{self, SerializeSeq};
39 let encoder = ExtensionRef::new(self.kind, self.data);
40 let format = encoder
41 .to_format::<core::convert::Infallible>()
42 .map_err(|_| ser::Error::custom("Invalid data length"))?;
43
44 let mut seq = serializer.serialize_seq(None)?;
45
46 seq.serialize_element(&Bytes(&format.as_slice()))?;
47
48 match format {
49 messagepack_core::Format::FixExt1
50 | messagepack_core::Format::FixExt2
51 | messagepack_core::Format::FixExt4
52 | messagepack_core::Format::FixExt8
53 | messagepack_core::Format::FixExt16 => {}
54
55 messagepack_core::Format::Ext8 => {
56 let len = self.data.len() as u8;
57 seq.serialize_element(&len)?;
58 }
59 messagepack_core::Format::Ext16 => {
60 let len = self.data.len() as u16;
61 seq.serialize_element(&len)?;
62 }
63 messagepack_core::Format::Ext32 => {
64 let len = self.data.len() as u32;
65 seq.serialize_element(&len)?;
66 }
67 _ => return Err(ser::Error::custom("unexpected format")),
68 };
69 seq.serialize_element(&self.kind)?;
70 seq.serialize_element(&Bytes(self.data))?;
71
72 seq.end()
73 }
74}
75
76pub mod ext_ref {
104 use super::*;
105 use serde::de;
106
107 pub fn serialize<S>(
109 ext: &messagepack_core::extension::ExtensionRef<'_>,
110 serializer: S,
111 ) -> Result<S::Ok, S::Error>
112 where
113 S: serde::Serializer,
114 {
115 serializer.serialize_newtype_struct(
116 EXTENSION_STRUCT_NAME,
117 &ExtInner {
118 kind: ext.r#type,
119 data: ext.data,
120 },
121 )
122 }
123
124 pub fn deserialize<'de, D>(
126 deserializer: D,
127 ) -> Result<messagepack_core::extension::ExtensionRef<'de>, D::Error>
128 where
129 D: serde::Deserializer<'de>,
130 {
131 struct ExtensionVisitor;
132
133 impl<'de> Visitor<'de> for ExtensionVisitor {
134 type Value = messagepack_core::extension::ExtensionRef<'de>;
135 fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
136 formatter.write_str("expect extension")
137 }
138
139 fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
140 where
141 D: de::Deserializer<'de>,
142 {
143 deserializer.deserialize_seq(self)
144 }
145
146 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
147 where
148 A: serde::de::SeqAccess<'de>,
149 {
150 let kind = seq
151 .next_element::<i8>()?
152 .ok_or(de::Error::missing_field("extension type missing"))?;
153
154 let data = seq
155 .next_element::<&[u8]>()?
156 .ok_or(de::Error::missing_field("extension data missing"))?;
157
158 Ok(messagepack_core::extension::ExtensionRef::new(kind, data))
159 }
160 }
161 deserializer.deserialize_seq(ExtensionVisitor)
162 }
163}
164
165pub mod ext_fixed {
193 use super::*;
194 use serde::{Deserialize, de};
195
196 pub fn serialize<const N: usize, S>(
198 ext: &messagepack_core::extension::FixedExtension<N>,
199 serializer: S,
200 ) -> Result<S::Ok, S::Error>
201 where
202 S: serde::Serializer,
203 {
204 super::ext_ref::serialize(&ext.as_ref(), serializer)
205 }
206
207 pub fn deserialize<'de, const N: usize, D>(
209 deserializer: D,
210 ) -> Result<messagepack_core::extension::FixedExtension<N>, D::Error>
211 where
212 D: serde::Deserializer<'de>,
213 {
214 struct Data<const N: usize> {
215 len: usize,
216 buf: [u8; N],
217 }
218 impl<'de, const N: usize> Deserialize<'de> for Data<N> {
219 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
220 where
221 D: de::Deserializer<'de>,
222 {
223 struct DataVisitor<const N: usize>;
224 impl<'de, const N: usize> Visitor<'de> for DataVisitor<N> {
225 type Value = Data<N>;
226 fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
227 formatter.write_str("expect extension")
228 }
229
230 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
231 where
232 E: de::Error,
233 {
234 let len = v.len();
235
236 if len > N {
237 return Err(de::Error::invalid_length(len, &self));
238 }
239
240 let mut buf = [0; N];
241 buf[..len].copy_from_slice(v);
242 Ok(Data { len, buf })
243 }
244 }
245 deserializer.deserialize_bytes(DataVisitor)
246 }
247 }
248
249 struct ExtensionVisitor<const N: usize>;
250 impl<'de, const N: usize> Visitor<'de> for ExtensionVisitor<N> {
251 type Value = messagepack_core::extension::FixedExtension<N>;
252 fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
253 formatter.write_str("expect extension")
254 }
255
256 fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
257 where
258 D: de::Deserializer<'de>,
259 {
260 deserializer.deserialize_seq(self)
261 }
262
263 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
264 where
265 A: de::SeqAccess<'de>,
266 {
267 let kind = seq
268 .next_element::<i8>()?
269 .ok_or(serde::de::Error::missing_field("extension type missing"))?;
270 let data = seq
271 .next_element::<Data<N>>()?
272 .ok_or(de::Error::missing_field("extension data missing"))?;
273
274 let ext = messagepack_core::extension::FixedExtension::new_fixed_with_prefix(
275 kind, data.len, data.buf,
276 )
277 .map_err(|_| de::Error::invalid_length(data.len, &"length is too long"))?;
278 Ok(ext)
279 }
280 }
281
282 deserializer.deserialize_seq(ExtensionVisitor)
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    use messagepack_core::extension::{ExtensionRef, FixedExtension};
    use rstest::rstest;
    use serde::{Deserialize, Serialize};

    /// Exercises `ext_ref`; `borrow` lets the payload point into the input buffer.
    #[derive(Debug, Serialize, Deserialize)]
    struct WrapRef<'a>(
        #[serde(with = "ext_ref", borrow)] messagepack_core::extension::ExtensionRef<'a>,
    );

    #[rstest]
    fn encode_ext_ref() {
        let mut buf = [0_u8; 3];

        let kind: i8 = 123;

        // A 1-byte payload must use fixext1 (0xd4): marker, type, payload.
        let ext = WrapRef(ExtensionRef::new(kind, &[0x12]));
        let length = crate::to_slice(&ext, &mut buf).unwrap();

        assert_eq!(length, 3);
        assert_eq!(buf, [0xd4, kind.to_be_bytes()[0], 0x12]);
    }

    #[rstest]
    fn decode_ext_ref() {
        // fixext4 (0xd6), type -1 (0xff), 4 zero bytes of payload.
        let buf = [0xd6, 0xff, 0x00, 0x00, 0x00, 0x00];
        let ext = crate::from_slice::<WrapRef<'_>>(&buf).unwrap().0;
        assert_eq!(ext.r#type, -1);
        let seconds = u32::from_be_bytes(ext.data.try_into().unwrap());
        assert_eq!(seconds, 0);
    }

    /// Exercises `ext_fixed` with an inline payload buffer of `N` bytes.
    #[derive(Debug, Serialize, Deserialize)]
    struct WrapFixed<const N: usize>(
        #[serde(with = "ext_fixed")] messagepack_core::extension::FixedExtension<N>,
    );

    #[rstest]
    fn encode_ext_fixed() {
        let mut buf = [0u8; 3];
        let kind: i8 = 123;

        let ext = WrapFixed(FixedExtension::new_fixed(kind, [0x12]));
        let length = crate::to_slice(&ext, &mut buf).unwrap();

        assert_eq!(length, 3);
        assert_eq!(buf, [0xd4, kind.to_be_bytes()[0], 0x12]);
    }

    /// fixext4 with type -1 and a 4-byte zero payload.
    const TIMESTAMP32: &[u8] = &[0xd6, 0xff, 0x00, 0x00, 0x00, 0x00];

    #[rstest]
    fn decode_ext_fixed_bigger_will_success() {
        let ext = crate::from_slice::<WrapFixed<6>>(TIMESTAMP32).unwrap().0;
        assert_eq!(ext.r#type, -1);
        assert_eq!(ext.as_slice(), &TIMESTAMP32[2..])
    }

    #[rstest]
    fn decode_ext_fixed_smaller_will_failed() {
        // A 4-byte payload cannot fit a 3-byte buffer. Decoding must return an
        // error; `should_panic` + `unwrap` would also pass on any unrelated panic.
        let result = crate::from_slice::<WrapFixed<3>>(TIMESTAMP32);
        assert!(result.is_err());
    }
}