memberlist_proto/
data.rs

1use std::borrow::Cow;
2
3use const_varint::{decode_u32_varint, encode_u32_varint_to, encoded_u32_varint_len};
4
5use super::WireType;
6
7pub use tuple::TupleEncoder;
8
9#[cfg(any(feature = "std", feature = "alloc"))]
10mod bytes;
11#[cfg(any(feature = "std", feature = "alloc"))]
12mod nodecraft;
13mod primitives;
14#[cfg(any(feature = "std", feature = "alloc"))]
15mod string;
16
17mod tuple;
18
19/// The reference type of the data.
20pub trait DataRef<'a, D>
21where
22  D: Data + ?Sized,
23  Self: Copy + core::fmt::Debug + Send + Sync,
24{
25  /// Decodes the reference type from a buffer.
26  ///
27  /// The entire buffer will be consumed.
28  fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError>
29  where
30    Self: Sized;
31
32  /// Decodes a length-delimited reference instance of the message from the buffer.
33  fn decode_length_delimited(src: &'a [u8]) -> Result<(usize, Self), DecodeError>
34  where
35    Self: Sized,
36  {
37    if D::WIRE_TYPE != WireType::LengthDelimited {
38      return Self::decode(src);
39    }
40
41    let (mut offset, len) = decode_u32_varint(src)?;
42    let len = len as usize;
43    if len + offset > src.len() {
44      return Err(DecodeError::buffer_underflow());
45    }
46
47    let src = &src[offset..offset + len];
48    let (bytes_read, value) = Self::decode(src)?;
49
50    #[cfg(debug_assertions)]
51    super::debug_assert_read_eq::<Self>(bytes_read, len);
52
53    offset += bytes_read;
54    Ok((offset, value))
55  }
56}
57
58/// The memberlist data can be transmitted through the network.
59pub trait Data: core::fmt::Debug + Send + Sync {
60  /// The wire type of the data.
61  const WIRE_TYPE: WireType = WireType::LengthDelimited;
62
63  /// The reference type of the data.
64  type Ref<'a>: DataRef<'a, Self>;
65
66  /// Converts the reference type to the owned type.
67  fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
68  where
69    Self: Sized;
70
71  /// Returns the encoded length of the data only considering the data itself, (e.g. no length prefix, no wire type).
72  fn encoded_len(&self) -> usize;
73
74  /// Returns the encoded length of the data including the length delimited.
75  fn encoded_len_with_length_delimited(&self) -> usize {
76    let len = self.encoded_len();
77    match Self::WIRE_TYPE {
78      WireType::LengthDelimited => encoded_u32_varint_len(len as u32) + len,
79      _ => len,
80    }
81  }
82
83  /// Encodes the message to a buffer.
84  ///
85  /// An error will be returned if the buffer does not have sufficient capacity.
86  fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError>;
87
88  /// Encodes the message into a vec.
89  #[cfg(any(feature = "std", feature = "alloc"))]
90  fn encode_to_vec(&self) -> Result<std::vec::Vec<u8>, EncodeError> {
91    let len = self.encoded_len();
92    let mut vec = std::vec![0; len];
93    self.encode(&mut vec).map(|_| vec)
94  }
95
96  /// Encodes the message into a [`Bytes`](::bytes::Bytes).
97  #[cfg(any(feature = "std", feature = "alloc"))]
98  fn encode_to_bytes(&self) -> Result<::bytes::Bytes, EncodeError> {
99    self.encode_to_vec().map(Into::into)
100  }
101
102  /// Encodes the message with a length-delimiter to a buffer.
103  ///
104  /// An error will be returned if the buffer does not have sufficient capacity.
105  fn encode_length_delimited(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
106    if Self::WIRE_TYPE != WireType::LengthDelimited {
107      return self.encode(buf);
108    }
109
110    let len = self.encoded_len();
111    if len > u32::MAX as usize {
112      return Err(EncodeError::TooLarge);
113    }
114
115    let mut offset = 0;
116    offset += encode_u32_varint_to(len as u32, buf)?;
117    offset += self.encode(&mut buf[offset..])?;
118
119    #[cfg(debug_assertions)]
120    super::debug_assert_write_eq::<Self>(offset, self.encoded_len_with_length_delimited());
121
122    Ok(offset)
123  }
124
125  /// Encodes the message with a length-delimiter into a vec.
126  #[cfg(any(feature = "std", feature = "alloc"))]
127  fn encode_length_delimited_to_vec(&self) -> Result<std::vec::Vec<u8>, EncodeError> {
128    let len = self.encoded_len_with_length_delimited();
129    let mut vec = std::vec![0; len];
130    self.encode_length_delimited(&mut vec).map(|_| vec)
131  }
132
133  /// Encodes the message with a length-delimiter into a [`Bytes`](::bytes::Bytes).
134  #[cfg(any(feature = "std", feature = "alloc"))]
135  fn encode_length_delimited_to_bytes(&self) -> Result<::bytes::Bytes, EncodeError> {
136    self.encode_length_delimited_to_vec().map(Into::into)
137  }
138
139  /// Decodes an instance of the message from a buffer.
140  ///
141  /// The entire buffer will be consumed.
142  fn decode(src: &[u8]) -> Result<(usize, Self), DecodeError>
143  where
144    Self: Sized,
145  {
146    <Self::Ref<'_> as DataRef<Self>>::decode(src)
147      .and_then(|(bytes_read, value)| Self::from_ref(value).map(|val| (bytes_read, val)))
148  }
149
150  /// Decodes a length-delimited instance of the message from the buffer.
151  fn decode_length_delimited(buf: &[u8]) -> Result<(usize, Self), DecodeError>
152  where
153    Self: Sized,
154  {
155    <Self::Ref<'_> as DataRef<Self>>::decode_length_delimited(buf)
156      .and_then(|(bytes_read, value)| Self::from_ref(value).map(|val| (bytes_read, val)))
157  }
158}
159
160/// A data encoding error
161#[derive(Debug, thiserror::Error)]
162pub enum EncodeError {
163  /// Returned when the encoded buffer is too small to hold the bytes format of the types.
164  #[error("insufficient buffer capacity, required: {required}, remaining: {remaining}")]
165  InsufficientBuffer {
166    /// The required buffer capacity.
167    required: usize,
168    /// The remaining buffer capacity.
169    remaining: usize,
170  },
171  /// Returned when the data in encoded format is larger than the maximum allowed size.
172  #[error("encoded data is too large, the maximum allowed size is {MAX} bytes", MAX = u32::MAX)]
173  TooLarge,
174  /// A custom encoding error.
175  #[error("{0}")]
176  Custom(Cow<'static, str>),
177}
178
179impl EncodeError {
180  /// Creates an insufficient buffer error.
181  #[inline]
182  pub const fn insufficient_buffer(required: usize, remaining: usize) -> Self {
183    Self::InsufficientBuffer {
184      required,
185      remaining,
186    }
187  }
188
189  /// Creates a custom encoding error.
190  pub fn custom<T>(value: T) -> Self
191  where
192    T: Into<Cow<'static, str>>,
193  {
194    Self::Custom(value.into())
195  }
196
197  /// Update the error with the required and remaining buffer capacity.
198  pub fn update(mut self, required: usize, remaining: usize) -> Self {
199    match self {
200      Self::InsufficientBuffer {
201        required: ref mut r,
202        remaining: ref mut rem,
203      } => {
204        *r = required;
205        *rem = remaining;
206        self
207      }
208      _ => self,
209    }
210  }
211}
212
213impl From<const_varint::EncodeError> for EncodeError {
214  #[inline]
215  fn from(value: const_varint::EncodeError) -> Self {
216    match value {
217      const_varint::EncodeError::Underflow {
218        required,
219        remaining,
220      } => Self::InsufficientBuffer {
221        required,
222        remaining,
223      },
224    }
225  }
226}
227
228impl From<Cow<'static, str>> for EncodeError {
229  fn from(value: Cow<'static, str>) -> Self {
230    Self::Custom(value)
231  }
232}
233
/// A message decoding error.
///
/// `DecodeError` indicates that the input buffer does not contain a valid
/// message. The error details should be considered 'best effort': in
/// general it is not possible to exactly pinpoint why data is malformed.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error, derive_more::IsVariant)]
pub enum DecodeError {
  /// Returned when the buffer does not have enough data to decode the message.
  #[error("buffer underflow")]
  BufferUnderflow,

  /// Returned when a required field is absent from the message being decoded.
  #[error("missing {field} in {ty}")]
  MissingField {
    /// The type of the message.
    ty: &'static str,
    /// The name of the field.
    field: &'static str,
  },

  /// Returned when the buffer contains more than one field with the same tag in a message.
  #[error("duplicate field {field} with tag {tag} in {ty}")]
  DuplicateField {
    /// The type of the message.
    ty: &'static str,
    /// The name of the field.
    field: &'static str,
    /// The tag of the field.
    tag: u8,
  },

  /// Returned when a field carries a wire type value that is not recognized.
  #[error("unknown wire type value {value} with tag {tag} when decoding {ty}")]
  UnknownWireType {
    /// The type of the message.
    ty: &'static str,
    /// The unknown wire type value.
    value: u8,
    /// The tag of the field.
    tag: u8,
  },

  /// Returned when a field tag is not recognized.
  #[error("unknown tag {tag} when decoding {ty}")]
  UnknownTag {
    /// The type of the message.
    ty: &'static str,
    /// The unknown tag value.
    tag: u8,
  },

  /// Returned when a varint, such as a length-delimiter prefix, overflows the maximum
  /// value of `u32`. This is also the result of converting a
  /// `const_varint::DecodeError::Overflow`.
  #[error("length-delimited overflow the maximum value of u32")]
  LengthDelimitedOverflow,

  /// A custom decoding error.
  #[error("{0}")]
  Custom(Cow<'static, str>),
}
293
294impl From<const_varint::DecodeError> for DecodeError {
295  #[inline]
296  fn from(e: const_varint::DecodeError) -> Self {
297    match e {
298      const_varint::DecodeError::Underflow => Self::BufferUnderflow,
299      const_varint::DecodeError::Overflow => Self::LengthDelimitedOverflow,
300    }
301  }
302}
303
304impl DecodeError {
305  /// Creates a new buffer underflow decoding error.
306  #[inline]
307  pub const fn buffer_underflow() -> Self {
308    Self::BufferUnderflow
309  }
310
311  /// Creates a new missing field decoding error.
312  #[inline]
313  pub const fn missing_field(ty: &'static str, field: &'static str) -> Self {
314    Self::MissingField { ty, field }
315  }
316
317  /// Creates a new duplicate field decoding error.
318  #[inline]
319  pub const fn duplicate_field(ty: &'static str, field: &'static str, tag: u8) -> Self {
320    Self::DuplicateField { ty, field, tag }
321  }
322
323  /// Creates a new unknown wire type decoding error.
324  #[inline]
325  pub const fn unknown_wire_type(ty: &'static str, value: u8, tag: u8) -> Self {
326    Self::UnknownWireType { ty, value, tag }
327  }
328
329  /// Creates a new unknown tag decoding error.
330  #[inline]
331  pub const fn unknown_tag(ty: &'static str, tag: u8) -> Self {
332    Self::UnknownTag { ty, tag }
333  }
334
335  /// Creates a custom decoding error.
336  #[inline]
337  pub fn custom<T>(value: T) -> Self
338  where
339    T: Into<Cow<'static, str>>,
340  {
341    Self::Custom(value.into())
342  }
343}