protosocket_messagepack/
lib.rs1use std::{io::Read, marker::PhantomData};
2
/// `protosocket::Serializer` that writes each message as a MessagePack `u32`
/// length header followed by the `rmp_serde` encoding of the message.
///
/// Carries no data of its own; the `PhantomData` marker only ties the
/// serializer to its message type `T`.
#[derive(Debug)]
pub struct ProtosocketMessagePackSerializer<T> {
    _phantom: PhantomData<T>,
}

8impl<T> Default for ProtosocketMessagePackSerializer<T> {
9 fn default() -> Self {
10 Self {
11 _phantom: PhantomData,
12 }
13 }
14}
15
16impl<T> protosocket::Serializer for ProtosocketMessagePackSerializer<T>
17where
18 T: serde::Serialize + Send + Unpin + std::fmt::Debug,
19{
20 type Message = T;
21
22 fn encode(&mut self, message: Self::Message, buffer: &mut Vec<u8>) {
23 log::debug!("encoding {message:?}");
24 buffer.extend_from_slice(&[0; 5]);
26 rmp_serde::encode::write(buffer, &message).expect("messages must be encodable");
27 let len = buffer.len();
28 unsafe {
29 buffer.set_len(0);
30 }
31 rmp::encode::write_u32(buffer, len as u32 - 5).expect("message length is encodable");
32 unsafe {
33 buffer.set_len(len);
34 }
35 }
36}
37
38#[derive(Debug)]
39pub struct ProtosocketMessagePackDeserializer<T> {
40 _phantom: std::marker::PhantomData<T>,
41 state: State,
42}
43
44impl<T> Default for ProtosocketMessagePackDeserializer<T> {
45 fn default() -> Self {
46 Self {
47 _phantom: PhantomData,
48 state: Default::default(),
49 }
50 }
51}
52
/// Progress of the deserializer through the frame at the front of the buffer,
/// carried across `decode` calls.
#[derive(Clone, Copy, Debug, Default)]
enum State {
    /// No header has been parsed yet for the next frame.
    #[default]
    Waiting,
    /// The header was parsed and declared a body of this many bytes, but the
    /// whole frame had not arrived yet.
    ReadingLength(u32),
}

60impl<T> protosocket::Deserializer for ProtosocketMessagePackDeserializer<T>
61where
62 T: serde::de::DeserializeOwned + Send + Unpin + std::fmt::Debug,
63{
64 type Message = T;
65
66 fn decode(
67 &mut self,
68 buffer: impl bytes::Buf,
69 ) -> std::result::Result<(usize, Self::Message), protosocket::DeserializeError> {
70 let start_remaining = buffer.remaining();
71 let mut reader = buffer.reader();
72 let length = match self.state {
73 State::Waiting => {
74 if start_remaining < 5 {
76 return Err(protosocket::DeserializeError::IncompleteBuffer {
77 next_message_size: 5,
78 });
79 }
80 let length: u32 = match rmp::decode::read_u32(&mut reader) {
81 Ok(length) => length,
82 Err(e) => {
83 log::error!("decode length error: {e:?}");
84 return Err(protosocket::DeserializeError::InvalidBuffer);
85 }
86 };
87 self.state = State::ReadingLength(length);
88 length
89 }
90 State::ReadingLength(length) => {
91 let _ = reader.read(&mut [0; 5]).expect("skip parsing");
92 length
93 }
94 };
95 if start_remaining < (length + 5) as usize {
96 return Err(protosocket::DeserializeError::IncompleteBuffer {
97 next_message_size: (length + 5) as usize,
98 });
99 }
100 self.state = State::Waiting;
101
102 rmp_serde::decode::from_read(&mut reader)
103 .map_err(|e| {
104 log::error!("decode error length {length}: {e:?}");
105 protosocket::DeserializeError::InvalidBuffer
106 })
107 .map(|message| {
108 let buffer = reader.into_inner();
109 let length = start_remaining - buffer.remaining();
110 log::debug!("decoded {length}: {message:?}");
111 (length, message)
112 })
113 }
114}