protosocket_messagepack/
decoder.rs

1use std::io::Read;
2use std::marker::PhantomData;
3
4/// A deserializer that takes a serde deserializable T and implements
5/// `protosocket::Decoder`. You can use this with a `protosocket`
6/// Connection or rpc.
7#[derive(Debug)]
8pub struct ProtosocketMessagePackDecoder<T> {
9    _phantom: std::marker::PhantomData<T>,
10    state: State,
11}
12
13impl<T> Default for ProtosocketMessagePackDecoder<T> {
14    fn default() -> Self {
15        Self {
16            _phantom: PhantomData,
17            state: Default::default(),
18        }
19    }
20}
21
/// Framing progress of the decoder between calls to `decode`.
#[derive(Debug, Clone, Copy, Default)]
enum State {
    /// No length prefix has been parsed for the next message yet.
    #[default]
    Waiting,
    /// The length prefix was already parsed; the payload holds the message
    /// body length in bytes and the decoder is waiting for the full body.
    ReadingLength(u32),
}
28
29impl<T> protosocket::Decoder for ProtosocketMessagePackDecoder<T>
30where
31    T: serde::de::DeserializeOwned + std::fmt::Debug,
32{
33    type Message = T;
34
35    fn decode(
36        &mut self,
37        buffer: impl bytes::Buf,
38    ) -> std::result::Result<(usize, Self::Message), protosocket::DeserializeError> {
39        let start_remaining = buffer.remaining();
40        let mut reader = buffer.reader();
41        let length = match self.state {
42            State::Waiting => {
43                // 1 byte for the number tag, 4 bytes for the message length
44                if start_remaining < 5 {
45                    return Err(protosocket::DeserializeError::IncompleteBuffer {
46                        next_message_size: 5,
47                    });
48                }
49                let length: u32 = match rmp::decode::read_u32(&mut reader) {
50                    Ok(length) => length,
51                    Err(e) => {
52                        log::error!("decode length error: {e:?}");
53                        return Err(protosocket::DeserializeError::InvalidBuffer);
54                    }
55                };
56                self.state = State::ReadingLength(length);
57                length
58            }
59            State::ReadingLength(length) => {
60                let _ = reader.read(&mut [0; 5]).expect("skip parsing");
61                length
62            }
63        };
64        if start_remaining < (length + 5) as usize {
65            return Err(protosocket::DeserializeError::IncompleteBuffer {
66                next_message_size: (length + 5) as usize,
67            });
68        }
69        self.state = State::Waiting;
70
71        rmp_serde::decode::from_read(&mut reader)
72            .map_err(|e| {
73                log::error!("decode error length {length}: {e:?}");
74                protosocket::DeserializeError::InvalidBuffer
75            })
76            .map(|message| {
77                let buffer = reader.into_inner();
78                let length = start_remaining - buffer.remaining();
79                log::debug!("decoded {length}: {message:?}");
80                (length, message)
81            })
82    }
83}