aldrin_core/
message_deserializer.rs

1use crate::buf_ext::MessageBufExt;
2use crate::message::MessageKind;
3use crate::serialized_value::SerializedValue;
4use bytes::{Buf, BytesMut};
5use thiserror::Error;
6use uuid::Uuid;
7
8pub(crate) struct MessageWithoutValueDeserializer {
9    buf: BytesMut,
10}
11
12impl MessageWithoutValueDeserializer {
13    pub fn new(mut buf: BytesMut, kind: MessageKind) -> Result<Self, MessageDeserializeError> {
14        let buf_len = buf.len();
15
16        // 4 bytes message length + 1 byte message kind.
17        if buf_len < 5 {
18            return Err(MessageDeserializeError::UnexpectedEoi);
19        }
20
21        let len = buf.get_u32_le() as usize;
22        if buf_len != len {
23            return Err(MessageDeserializeError::InvalidSerialization);
24        }
25
26        buf.ensure_discriminant_u8(kind)?;
27
28        Ok(Self { buf })
29    }
30
31    pub fn try_get_discriminant_u8<T: TryFrom<u8>>(
32        &mut self,
33    ) -> Result<T, MessageDeserializeError> {
34        self.buf.try_get_discriminant_u8()
35    }
36
37    pub fn try_get_varint_u32_le(&mut self) -> Result<u32, MessageDeserializeError> {
38        self.buf.try_get_varint_u32_le()
39    }
40
41    pub fn try_get_uuid(&mut self) -> Result<Uuid, MessageDeserializeError> {
42        let mut bytes = uuid::Bytes::default();
43        self.buf.try_copy_to_slice(&mut bytes)?;
44        Ok(Uuid::from_bytes(bytes))
45    }
46
47    pub fn finish(self) -> Result<(), MessageDeserializeError> {
48        if self.buf.is_empty() {
49            Ok(())
50        } else {
51            Err(MessageDeserializeError::TrailingData)
52        }
53    }
54}
55
56pub(crate) struct MessageWithValueDeserializer {
57    header_and_value: BytesMut,
58    msg: BytesMut,
59}
60
61impl MessageWithValueDeserializer {
62    pub fn new(mut buf: BytesMut, kind: MessageKind) -> Result<Self, MessageDeserializeError> {
63        debug_assert!(kind.has_value());
64
65        // 4 bytes message length + 1 byte message kind + 4 bytes value length + at least 1 byte
66        // value.
67        if buf.len() < 10 {
68            return Err(MessageDeserializeError::UnexpectedEoi);
69        }
70
71        let msg_len = (&buf[..4]).get_u32_le() as usize;
72        if buf.len() != msg_len {
73            return Err(MessageDeserializeError::InvalidSerialization);
74        }
75
76        if buf[4] != kind.into() {
77            return Err(MessageDeserializeError::UnexpectedMessage);
78        }
79
80        let value_len = (&buf[5..9]).get_u32_le() as usize;
81        let max_value_len = buf.len() - 9;
82
83        if value_len < 1 {
84            return Err(MessageDeserializeError::InvalidSerialization);
85        } else if value_len > max_value_len {
86            return Err(MessageDeserializeError::UnexpectedEoi);
87        }
88
89        let msg = buf.split_off(9 + value_len);
90        Ok(Self {
91            header_and_value: buf,
92            msg,
93        })
94    }
95
96    pub fn try_get_discriminant_u8<T: TryFrom<u8>>(
97        &mut self,
98    ) -> Result<T, MessageDeserializeError> {
99        self.msg.try_get_discriminant_u8()
100    }
101
102    pub fn try_get_varint_u32_le(&mut self) -> Result<u32, MessageDeserializeError> {
103        self.msg.try_get_varint_u32_le()
104    }
105
106    pub fn try_get_uuid(&mut self) -> Result<Uuid, MessageDeserializeError> {
107        let mut bytes = uuid::Bytes::default();
108        self.msg.try_copy_to_slice(&mut bytes)?;
109        Ok(Uuid::from_bytes(bytes))
110    }
111
112    pub fn finish(mut self) -> Result<SerializedValue, MessageDeserializeError> {
113        if self.msg.is_empty() {
114            self.header_and_value.unsplit(self.msg);
115            self.header_and_value[0..9].fill(0);
116            Ok(SerializedValue::from_bytes_mut(self.header_and_value))
117        } else {
118            Err(MessageDeserializeError::TrailingData)
119        }
120    }
121
122    pub fn finish_discard_value(self) -> Result<(), MessageDeserializeError> {
123        if self.msg.is_empty() {
124            Ok(())
125        } else {
126            Err(MessageDeserializeError::TrailingData)
127        }
128    }
129}
130
131#[derive(Error, Debug, Copy, Clone, PartialEq, Eq)]
132pub enum MessageDeserializeError {
133    #[error("invalid serialization")]
134    InvalidSerialization,
135
136    #[error("unexpected end of input")]
137    UnexpectedEoi,
138
139    #[error("unexpected message type")]
140    UnexpectedMessage,
141
142    #[error("serialization contains trailing data")]
143    TrailingData,
144}