protosocket_messagepack/
lib.rs

1use std::{io::Read, marker::PhantomData};
2
/// Serializer that frames messages of type `T` as length-prefixed MessagePack.
///
/// The type carries no runtime state; `T` only selects the message type used
/// by the `protosocket::Serializer` implementation.
#[derive(Debug)]
pub struct ProtosocketMessagePackSerializer<T> {
    /// Zero-sized marker tying this serializer to its message type.
    _phantom: PhantomData<T>,
}
7
8impl<T> Default for ProtosocketMessagePackSerializer<T> {
9    fn default() -> Self {
10        Self {
11            _phantom: PhantomData,
12        }
13    }
14}
15
16impl<T> protosocket::Serializer for ProtosocketMessagePackSerializer<T>
17where
18    T: serde::Serialize + Send + Unpin + std::fmt::Debug,
19{
20    type Message = T;
21
22    fn encode(&mut self, message: Self::Message, buffer: &mut Vec<u8>) {
23        log::debug!("encoding {message:?}");
24        // reserve length prefix
25        buffer.extend_from_slice(&[0; 5]);
26        rmp_serde::encode::write(buffer, &message).expect("messages must be encodable");
27        let len = buffer.len();
28        unsafe {
29            buffer.set_len(0);
30        }
31        rmp::encode::write_u32(buffer, len as u32 - 5).expect("message length is encodable");
32        unsafe {
33            buffer.set_len(len);
34        }
35    }
36}
37
38#[derive(Debug)]
39pub struct ProtosocketMessagePackDeserializer<T> {
40    _phantom: std::marker::PhantomData<T>,
41    state: State,
42}
43
44impl<T> Default for ProtosocketMessagePackDeserializer<T> {
45    fn default() -> Self {
46        Self {
47            _phantom: PhantomData,
48            state: Default::default(),
49        }
50    }
51}
52
/// Decoder progress for the frame at the front of the input buffer.
#[derive(Clone, Copy, Debug, Default)]
enum State {
    /// No length prefix has been parsed yet for the next frame.
    #[default]
    Waiting,
    /// The length prefix was parsed earlier; the value is the payload size in
    /// bytes, excluding the prefix itself.
    ReadingLength(u32),
}
59
60impl<T> protosocket::Deserializer for ProtosocketMessagePackDeserializer<T>
61where
62    T: serde::de::DeserializeOwned + Send + Unpin + std::fmt::Debug,
63{
64    type Message = T;
65
66    fn decode(
67        &mut self,
68        buffer: impl bytes::Buf,
69    ) -> std::result::Result<(usize, Self::Message), protosocket::DeserializeError> {
70        let start_remaining = buffer.remaining();
71        let mut reader = buffer.reader();
72        let length = match self.state {
73            State::Waiting => {
74                // 1 byte for the number tag, 4 bytes for the message length
75                if start_remaining < 5 {
76                    return Err(protosocket::DeserializeError::IncompleteBuffer {
77                        next_message_size: 5,
78                    });
79                }
80                let length: u32 = match rmp::decode::read_u32(&mut reader) {
81                    Ok(length) => length,
82                    Err(e) => {
83                        log::error!("decode length error: {e:?}");
84                        return Err(protosocket::DeserializeError::InvalidBuffer);
85                    }
86                };
87                self.state = State::ReadingLength(length);
88                length
89            }
90            State::ReadingLength(length) => {
91                let _ = reader.read(&mut [0; 5]).expect("skip parsing");
92                length
93            }
94        };
95        if start_remaining < (length + 5) as usize {
96            return Err(protosocket::DeserializeError::IncompleteBuffer {
97                next_message_size: (length + 5) as usize,
98            });
99        }
100        self.state = State::Waiting;
101
102        rmp_serde::decode::from_read(&mut reader)
103            .map_err(|e| {
104                log::error!("decode error length {length}: {e:?}");
105                protosocket::DeserializeError::InvalidBuffer
106            })
107            .map(|message| {
108                let buffer = reader.into_inner();
109                let length = start_remaining - buffer.remaining();
110                log::debug!("decoded {length}: {message:?}");
111                (length, message)
112            })
113    }
114}