1use std::io;
2
3use asynchronous_codec::{BytesMut, Decoder, Encoder};
4use unsigned_varint::codec::UviBytes;
5
6pub struct ProtobufUviCodec<M> {
7 uvi_codec: UviBytes,
8 _marker: std::marker::PhantomData<M>,
9}
10
11impl<M> Default for ProtobufUviCodec<M> {
12 fn default() -> Self {
13 ProtobufUviCodec {
14 uvi_codec: UviBytes::default(),
15 _marker: std::marker::PhantomData,
16 }
17 }
18}
19
20impl<M> Clone for ProtobufUviCodec<M> {
21 fn clone(&self) -> Self {
22 let mut uvi = UviBytes::default();
23 uvi.set_max_len(self.uvi_codec.max_len());
24 ProtobufUviCodec {
25 uvi_codec: uvi,
26 _marker: std::marker::PhantomData,
27 }
28 }
29}
30
31impl<M> ProtobufUviCodec<M> {
32 pub fn new(max_size: usize) -> Self {
33 Self::default().set_max_len(max_size)
34 }
35
36 pub fn set_max_len(mut self, val: usize) -> Self {
37 self.uvi_codec.set_max_len(val);
38 self
39 }
40
41 pub fn max_len(&self) -> usize {
42 self.uvi_codec.max_len()
43 }
44}
45
46impl<M> Encoder for ProtobufUviCodec<M>
47where
48 M: prost::Message,
49{
50 type Item<'a> = M;
51 type Error = io::Error;
52
53 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
54 let len = item.encoded_len();
55 let mut buffer = BytesMut::with_capacity(len);
56 item.encode(&mut buffer)?;
57 self.uvi_codec.encode(buffer.freeze(), dst)
58 }
59}
60
61impl<M> Decoder for ProtobufUviCodec<M>
62where
63 M: prost::Message + Default,
64{
65 type Item = M;
66 type Error = io::Error;
67
68 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
69 match self.uvi_codec.decode(src) {
70 Ok(Some(bytes)) => {
71 let item = M::decode(bytes.as_ref())
72 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
73 Ok(Some(item))
74 }
75 Ok(None) => Ok(None),
76 Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)),
77 }
78 }
79}