Skip to main content

protosocket_prost/
prost_serializer.rs

1use std::marker::PhantomData;
2
3use protosocket::{Decoder, DeserializeError, Serialize};
4
/// A stateless [`Serialize`] implementation backed by `prost`.
///
/// The type holds no data of its own: `Message` only selects which `prost`
/// message type the serializer accepts, so the single field is a
/// [`PhantomData`] marker.
#[derive(Debug, Default)]
pub struct ProstSerializer<Message> {
    /// Ties the otherwise unused `Message` type parameter to the struct.
    _phantom: PhantomData<Message>,
}
10
11impl<Message> Serialize for ProstSerializer<Message>
12where
13    Message: prost::Message + std::fmt::Debug,
14{
15    type Message = Message;
16
17    fn serialize_into_buffer(&mut self, message: Self::Message, buffer: &mut Vec<u8>) {
18        match message.encode_length_delimited(buffer) {
19            Ok(_) => {
20                log::debug!("encoded {message:?}");
21            }
22            Err(e) => {
23                log::error!("encoding error: {e:?}");
24            }
25        }
26    }
27}
28
/// A stateless [`Decoder`] implementation backed by `prost`.
///
/// The type holds no data of its own: `Message` only selects which `prost`
/// message type the decoder produces, so the single field is a
/// [`PhantomData`] marker.
#[derive(Debug, Default)]
pub struct ProstDecoder<Message> {
    /// Ties the otherwise unused `Message` type parameter to the struct.
    _phantom: PhantomData<Message>,
}
34impl<Message> Decoder for ProstDecoder<Message>
35where
36    Message: prost::Message + Default + std::fmt::Debug,
37{
38    type Message = Message;
39
40    fn decode(
41        &mut self,
42        mut buffer: impl bytes::Buf,
43    ) -> std::result::Result<(usize, Self::Message), DeserializeError> {
44        match prost::decode_length_delimiter(buffer.chunk()) {
45            Ok(message_length) => {
46                if buffer.remaining() < message_length + prost::length_delimiter_len(message_length)
47                {
48                    return Err(DeserializeError::IncompleteBuffer {
49                        next_message_size: message_length,
50                    });
51                }
52            }
53            Err(e) => {
54                log::trace!("can't read a length delimiter {e:?}");
55                return Err(DeserializeError::IncompleteBuffer {
56                    next_message_size: 10,
57                });
58            }
59        };
60
61        let start = buffer.remaining();
62        match <Self::Message as prost::Message>::decode_length_delimited(&mut buffer) {
63            Ok(message) => {
64                let length = start - buffer.remaining();
65                log::debug!("decoded {length}: {message:?}");
66                Ok((length, message))
67            }
68            Err(e) => {
69                log::warn!("could not decode message: {e:?}");
70                Err(DeserializeError::InvalidBuffer)
71            }
72        }
73    }
74}