protosocket_prost/
prost_serializer.rs

1use std::marker::PhantomData;
2
3use protosocket::{DeserializeError, Deserializer, Serializer};
4
/// A stateless implementation of protosocket's `Serializer` and `Deserializer`
/// traits using `prost` for encoding and decoding protocol buffers messages.
///
/// `Deserialized` is the message type read from the wire and `Serialized` is
/// the message type written to it. Messages are framed length-delimited: a
/// varint length prefix followed by the protobuf-encoded message.
#[derive(Debug)]
pub struct ProstSerializer<Deserialized, Serialized> {
    pub(crate) _phantom: PhantomData<(Deserialized, Serialized)>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would require
// `Deserialized: Default` and `Serialized: Default`, even though the struct
// only stores a `PhantomData` and never needs a value of either type.
impl<Deserialized, Serialized> Default for ProstSerializer<Deserialized, Serialized> {
    fn default() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}
11
12impl<Deserialized, Serialized> Serializer for ProstSerializer<Deserialized, Serialized>
13where
14    Deserialized: prost::Message + Default + Unpin,
15    Serialized: prost::Message + Unpin,
16{
17    type Message = Serialized;
18
19    fn encode(&mut self, message: Self::Message, buffer: &mut Vec<u8>) {
20        match message.encode_length_delimited(buffer) {
21            Ok(_) => {
22                log::debug!("encoded {message:?}");
23            }
24            Err(e) => {
25                log::error!("encoding error: {e:?}");
26            }
27        }
28    }
29}
30impl<Deserialized, Serialized> Deserializer for ProstSerializer<Deserialized, Serialized>
31where
32    Deserialized: prost::Message + Default + Unpin,
33    Serialized: prost::Message + Unpin,
34{
35    type Message = Deserialized;
36
37    fn decode(
38        &mut self,
39        mut buffer: impl bytes::Buf,
40    ) -> std::result::Result<(usize, Self::Message), DeserializeError> {
41        match prost::decode_length_delimiter(buffer.chunk()) {
42            Ok(message_length) => {
43                if buffer.remaining() < message_length + prost::length_delimiter_len(message_length)
44                {
45                    return Err(DeserializeError::IncompleteBuffer {
46                        next_message_size: message_length,
47                    });
48                }
49            }
50            Err(e) => {
51                log::trace!("can't read a length delimiter {e:?}");
52                return Err(DeserializeError::IncompleteBuffer {
53                    next_message_size: 10,
54                });
55            }
56        };
57
58        let start = buffer.remaining();
59        match <Self::Message as prost::Message>::decode_length_delimited(&mut buffer) {
60            Ok(message) => {
61                let length = start - buffer.remaining();
62                log::debug!("decoded {length}: {message:?}");
63                Ok((length, message))
64            }
65            Err(e) => {
66                log::warn!("could not decode message: {e:?}");
67                Err(DeserializeError::InvalidBuffer)
68            }
69        }
70    }
71}