protosocket_messagepack/
decoder.rs1use std::io::Read;
2use std::marker::PhantomData;
3
4#[derive(Debug)]
8pub struct ProtosocketMessagePackDecoder<T> {
9 _phantom: std::marker::PhantomData<T>,
10 state: State,
11}
12
13impl<T> Default for ProtosocketMessagePackDecoder<T> {
14 fn default() -> Self {
15 Self {
16 _phantom: PhantomData,
17 state: Default::default(),
18 }
19 }
20}
21
/// Position of the decoder within the frame currently being received.
#[derive(Clone, Copy, Debug, Default)]
enum State {
    /// No length prefix has been parsed yet for the next frame.
    #[default]
    Waiting,
    /// A length prefix was parsed on an earlier call. The payload is the body length it
    /// declared, and the decoder is waiting for the rest of that frame.
    ReadingLength(u32),
}
28
29impl<T> protosocket::Decoder for ProtosocketMessagePackDecoder<T>
30where
31 T: serde::de::DeserializeOwned + std::fmt::Debug,
32{
33 type Message = T;
34
35 fn decode(
36 &mut self,
37 buffer: impl bytes::Buf,
38 ) -> std::result::Result<(usize, Self::Message), protosocket::DeserializeError> {
39 let start_remaining = buffer.remaining();
40 let mut reader = buffer.reader();
41 let length = match self.state {
42 State::Waiting => {
43 if start_remaining < 5 {
45 return Err(protosocket::DeserializeError::IncompleteBuffer {
46 next_message_size: 5,
47 });
48 }
49 let length: u32 = match rmp::decode::read_u32(&mut reader) {
50 Ok(length) => length,
51 Err(e) => {
52 log::error!("decode length error: {e:?}");
53 return Err(protosocket::DeserializeError::InvalidBuffer);
54 }
55 };
56 self.state = State::ReadingLength(length);
57 length
58 }
59 State::ReadingLength(length) => {
60 let _ = reader.read(&mut [0; 5]).expect("skip parsing");
61 length
62 }
63 };
64 if start_remaining < (length + 5) as usize {
65 return Err(protosocket::DeserializeError::IncompleteBuffer {
66 next_message_size: (length + 5) as usize,
67 });
68 }
69 self.state = State::Waiting;
70
71 rmp_serde::decode::from_read(&mut reader)
72 .map_err(|e| {
73 log::error!("decode error length {length}: {e:?}");
74 protosocket::DeserializeError::InvalidBuffer
75 })
76 .map(|message| {
77 let buffer = reader.into_inner();
78 let length = start_remaining - buffer.remaining();
79 log::debug!("decoded {length}: {message:?}");
80 (length, message)
81 })
82 }
83}