aldrin_core/
message_deserializer.rs1use crate::buf_ext::MessageBufExt;
2use crate::message::MessageKind;
3use crate::serialized_value::SerializedValue;
4use bytes::{Buf, BytesMut};
5use thiserror::Error;
6use uuid::Uuid;
7
8pub(crate) struct MessageWithoutValueDeserializer {
9 buf: BytesMut,
10}
11
12impl MessageWithoutValueDeserializer {
13 pub fn new(mut buf: BytesMut, kind: MessageKind) -> Result<Self, MessageDeserializeError> {
14 let buf_len = buf.len();
15
16 if buf_len < 5 {
18 return Err(MessageDeserializeError::UnexpectedEoi);
19 }
20
21 let len = buf.get_u32_le() as usize;
22 if buf_len != len {
23 return Err(MessageDeserializeError::InvalidSerialization);
24 }
25
26 buf.ensure_discriminant_u8(kind)?;
27
28 Ok(Self { buf })
29 }
30
31 pub fn try_get_discriminant_u8<T: TryFrom<u8>>(
32 &mut self,
33 ) -> Result<T, MessageDeserializeError> {
34 self.buf.try_get_discriminant_u8()
35 }
36
37 pub fn try_get_varint_u32_le(&mut self) -> Result<u32, MessageDeserializeError> {
38 self.buf.try_get_varint_u32_le()
39 }
40
41 pub fn try_get_uuid(&mut self) -> Result<Uuid, MessageDeserializeError> {
42 let mut bytes = uuid::Bytes::default();
43 self.buf.try_copy_to_slice(&mut bytes)?;
44 Ok(Uuid::from_bytes(bytes))
45 }
46
47 pub fn finish(self) -> Result<(), MessageDeserializeError> {
48 if self.buf.is_empty() {
49 Ok(())
50 } else {
51 Err(MessageDeserializeError::TrailingData)
52 }
53 }
54}
55
56pub(crate) struct MessageWithValueDeserializer {
57 header_and_value: BytesMut,
58 msg: BytesMut,
59}
60
61impl MessageWithValueDeserializer {
62 pub fn new(mut buf: BytesMut, kind: MessageKind) -> Result<Self, MessageDeserializeError> {
63 debug_assert!(kind.has_value());
64
65 if buf.len() < 10 {
68 return Err(MessageDeserializeError::UnexpectedEoi);
69 }
70
71 let msg_len = (&buf[..4]).get_u32_le() as usize;
72 if buf.len() != msg_len {
73 return Err(MessageDeserializeError::InvalidSerialization);
74 }
75
76 if buf[4] != kind.into() {
77 return Err(MessageDeserializeError::UnexpectedMessage);
78 }
79
80 let value_len = (&buf[5..9]).get_u32_le() as usize;
81 let max_value_len = buf.len() - 9;
82
83 if value_len < 1 {
84 return Err(MessageDeserializeError::InvalidSerialization);
85 } else if value_len > max_value_len {
86 return Err(MessageDeserializeError::UnexpectedEoi);
87 }
88
89 let msg = buf.split_off(9 + value_len);
90 Ok(Self {
91 header_and_value: buf,
92 msg,
93 })
94 }
95
96 pub fn try_get_discriminant_u8<T: TryFrom<u8>>(
97 &mut self,
98 ) -> Result<T, MessageDeserializeError> {
99 self.msg.try_get_discriminant_u8()
100 }
101
102 pub fn try_get_varint_u32_le(&mut self) -> Result<u32, MessageDeserializeError> {
103 self.msg.try_get_varint_u32_le()
104 }
105
106 pub fn try_get_uuid(&mut self) -> Result<Uuid, MessageDeserializeError> {
107 let mut bytes = uuid::Bytes::default();
108 self.msg.try_copy_to_slice(&mut bytes)?;
109 Ok(Uuid::from_bytes(bytes))
110 }
111
112 pub fn finish(mut self) -> Result<SerializedValue, MessageDeserializeError> {
113 if self.msg.is_empty() {
114 self.header_and_value.unsplit(self.msg);
115 self.header_and_value[0..9].fill(0);
116 Ok(SerializedValue::from_bytes_mut(self.header_and_value))
117 } else {
118 Err(MessageDeserializeError::TrailingData)
119 }
120 }
121
122 pub fn finish_discard_value(self) -> Result<(), MessageDeserializeError> {
123 if self.msg.is_empty() {
124 Ok(())
125 } else {
126 Err(MessageDeserializeError::TrailingData)
127 }
128 }
129}
130
131#[derive(Error, Debug, Copy, Clone, PartialEq, Eq)]
132pub enum MessageDeserializeError {
133 #[error("invalid serialization")]
134 InvalidSerialization,
135
136 #[error("unexpected end of input")]
137 UnexpectedEoi,
138
139 #[error("unexpected message type")]
140 UnexpectedMessage,
141
142 #[error("serialization contains trailing data")]
143 TrailingData,
144}