memberlist_types/
label.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use nodecraft::CheapClone;
3
4/// Invalid label error.
5#[derive(Debug, thiserror::Error)]
6pub enum InvalidLabel {
7  /// The label is too large.
8  #[error("the size of label must between [0-255] bytes, got {0}")]
9  TooLarge(usize),
10  /// The label is not valid utf8.
11  #[error(transparent)]
12  Utf8(#[from] core::str::Utf8Error),
13}
14
15/// General approach is to prefix all packets and streams with the same structure:
16///
17/// Encode:
18/// ```text
19///   magic type byte (244): u8
20///   length of label name:  u8 (because labels can't be longer than 253 bytes)
21///   label name:            bytes (max 253 bytes)
22/// ```
23#[derive(Clone)]
24pub struct Label(Bytes);
25
26impl CheapClone for Label {}
27
28impl Label {
29  /// The maximum size of a name in bytes.
30  pub const MAX_SIZE: usize = u8::MAX as usize - 2;
31
32  /// The tag for a label when encoding/decoding.
33  pub const TAG: u8 = 127;
34
35  /// Create an empty label.
36  #[inline]
37  pub const fn empty() -> Label {
38    Label(Bytes::new())
39  }
40
41  /// The encoded overhead of a label.
42  #[inline]
43  pub fn encoded_overhead(&self) -> usize {
44    if self.is_empty() {
45      0
46    } else {
47      2 + self.len()
48    }
49  }
50
51  /// Create a label from a static str.
52  #[inline]
53  pub fn from_static(s: &'static str) -> Result<Self, InvalidLabel> {
54    Self::try_from(s)
55  }
56
57  /// Returns the label as a byte slice.
58  #[inline]
59  pub fn as_bytes(&self) -> &[u8] {
60    &self.0
61  }
62
63  /// Returns the str of the label.
64  #[inline]
65  pub fn as_str(&self) -> &str {
66    core::str::from_utf8(&self.0).unwrap()
67  }
68
69  /// Returns true if the label is empty.
70  #[inline]
71  pub fn is_empty(&self) -> bool {
72    self.0.is_empty()
73  }
74
75  /// Returns the length of the label in bytes.
76  #[inline]
77  pub fn len(&self) -> usize {
78    self.0.len()
79  }
80}
81
#[cfg(feature = "serde")]
const _: () = {
  use serde::{Deserialize, Deserializer, Serialize, Serializer};

  /// Human-readable formats receive the label as a string; binary formats
  /// receive the raw bytes.
  impl Serialize for Label {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
      S: Serializer,
    {
      if !serializer.is_human_readable() {
        return serializer.serialize_bytes(self.as_bytes());
      }
      serializer.serialize_str(self.as_str())
    }
  }

  /// Mirrors the `Serialize` impl. The decoded value goes through `TryFrom`,
  /// so the size and UTF-8 checks still apply.
  impl<'de> Deserialize<'de> for Label {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
      D: Deserializer<'de>,
    {
      let converted = if deserializer.is_human_readable() {
        let text = String::deserialize(deserializer)?;
        Label::try_from(text)
      } else {
        let raw = Bytes::deserialize(deserializer)?;
        Label::try_from(raw)
      };
      converted.map_err(serde::de::Error::custom)
    }
  }
};
111
112impl AsRef<str> for Label {
113  fn as_ref(&self) -> &str {
114    self.as_str()
115  }
116}
117
118impl core::cmp::PartialOrd for Label {
119  fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
120    Some(self.cmp(other))
121  }
122}
123
124impl core::cmp::Ord for Label {
125  fn cmp(&self, other: &Self) -> core::cmp::Ordering {
126    self.as_str().cmp(other.as_str())
127  }
128}
129
130impl core::cmp::PartialEq for Label {
131  fn eq(&self, other: &Self) -> bool {
132    self.as_str() == other.as_str()
133  }
134}
135
136impl core::cmp::PartialEq<str> for Label {
137  fn eq(&self, other: &str) -> bool {
138    self.as_str() == other
139  }
140}
141
142impl core::cmp::PartialEq<&str> for Label {
143  fn eq(&self, other: &&str) -> bool {
144    self.as_str() == *other
145  }
146}
147
148impl core::cmp::PartialEq<String> for Label {
149  fn eq(&self, other: &String) -> bool {
150    self.as_str() == other
151  }
152}
153
154impl core::cmp::PartialEq<&String> for Label {
155  fn eq(&self, other: &&String) -> bool {
156    self.as_str() == *other
157  }
158}
159
160impl core::cmp::Eq for Label {}
161
162impl core::hash::Hash for Label {
163  fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
164    self.as_str().hash(state)
165  }
166}
167
168impl TryFrom<&str> for Label {
169  type Error = InvalidLabel;
170
171  fn try_from(s: &str) -> Result<Self, Self::Error> {
172    if s.len() > Self::MAX_SIZE {
173      return Err(InvalidLabel::TooLarge(s.len()));
174    }
175    Ok(Self(Bytes::copy_from_slice(s.as_bytes())))
176  }
177}
178
179impl TryFrom<&String> for Label {
180  type Error = InvalidLabel;
181
182  fn try_from(s: &String) -> Result<Self, Self::Error> {
183    s.as_str().try_into()
184  }
185}
186
187impl TryFrom<String> for Label {
188  type Error = InvalidLabel;
189
190  fn try_from(s: String) -> Result<Self, Self::Error> {
191    if s.len() > Self::MAX_SIZE {
192      return Err(InvalidLabel::TooLarge(s.len()));
193    }
194    Ok(Self(s.into()))
195  }
196}
197
198impl TryFrom<Bytes> for Label {
199  type Error = InvalidLabel;
200
201  fn try_from(s: Bytes) -> Result<Self, Self::Error> {
202    if s.len() > Self::MAX_SIZE {
203      return Err(InvalidLabel::TooLarge(s.len()));
204    }
205    match core::str::from_utf8(s.as_ref()) {
206      Ok(_) => Ok(Self(s)),
207      Err(e) => Err(InvalidLabel::Utf8(e)),
208    }
209  }
210}
211
212impl TryFrom<Vec<u8>> for Label {
213  type Error = InvalidLabel;
214
215  fn try_from(s: Vec<u8>) -> Result<Self, Self::Error> {
216    Label::try_from(Bytes::from(s))
217  }
218}
219
220impl TryFrom<&[u8]> for Label {
221  type Error = InvalidLabel;
222
223  fn try_from(s: &[u8]) -> Result<Self, Self::Error> {
224    if s.len() > Self::MAX_SIZE {
225      return Err(InvalidLabel::TooLarge(s.len()));
226    }
227    match core::str::from_utf8(s) {
228      Ok(_) => Ok(Self(Bytes::copy_from_slice(s))),
229      Err(e) => Err(InvalidLabel::Utf8(e)),
230    }
231  }
232}
233
234impl TryFrom<&Bytes> for Label {
235  type Error = InvalidLabel;
236
237  fn try_from(s: &Bytes) -> Result<Self, Self::Error> {
238    if s.len() > Self::MAX_SIZE {
239      return Err(InvalidLabel::TooLarge(s.len()));
240    }
241    match core::str::from_utf8(s.as_ref()) {
242      Ok(_) => Ok(Self(s.clone())),
243      Err(e) => Err(InvalidLabel::Utf8(e)),
244    }
245  }
246}
247
248impl TryFrom<BytesMut> for Label {
249  type Error = InvalidLabel;
250
251  fn try_from(s: BytesMut) -> Result<Self, Self::Error> {
252    if s.len() > Self::MAX_SIZE {
253      return Err(InvalidLabel::TooLarge(s.len()));
254    }
255    match core::str::from_utf8(s.as_ref()) {
256      Ok(_) => Ok(Self(s.freeze())),
257      Err(e) => Err(InvalidLabel::Utf8(e)),
258    }
259  }
260}
261
262impl core::fmt::Debug for Label {
263  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
264    write!(f, "{}", self.as_str())
265  }
266}
267
268impl core::fmt::Display for Label {
269  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
270    write!(f, "{}", self.as_str())
271  }
272}
273
274/// Label error.
275#[derive(Debug, thiserror::Error)]
276pub enum LabelError {
277  /// Invalid label.
278  #[error(transparent)]
279  InvalidLabel(#[from] InvalidLabel),
280  /// Not enough bytes to decode label.
281  #[error("not enough bytes to decode label")]
282  NotEnoughBytes,
283  /// Label mismatch.
284  #[error("label mismatch: expected {expected}, got {got}")]
285  LabelMismatch {
286    /// Expected label.
287    expected: Label,
288    /// Got label.
289    got: Label,
290  },
291
292  /// Unexpected double label header
293  #[error("unexpected double label header, inbound label check is disabled, but got double label header: local={local}, remote={remote}")]
294  Duplicate {
295    /// The local label.
296    local: Label,
297    /// The remote label.
298    remote: Label,
299  },
300}
301
302impl LabelError {
303  /// Creates a new `LabelError::LabelMismatch`.
304  pub fn mismatch(expected: Label, got: Label) -> Self {
305    Self::LabelMismatch { expected, got }
306  }
307
308  /// Creates a new `LabelError::Duplicate`.
309  pub fn duplicate(local: Label, remote: Label) -> Self {
310    Self::Duplicate { local, remote }
311  }
312}
313
314/// Label extension for [`Buf`] types.
315pub trait LabelBufExt: Buf + sealed::Splitable + TryInto<Label, Error = InvalidLabel> {
316  /// Remove the label prefix from the buffer.
317  fn remove_label_header(&mut self) -> Result<Option<Label>, LabelError>
318  where
319    Self: Sized,
320  {
321    if self.remaining() < 1 {
322      return Ok(None);
323    }
324
325    let data = self.chunk();
326    if data[0] != Label::TAG {
327      return Ok(None);
328    }
329    self.advance(1);
330    let len = self.get_u8() as usize;
331    if len > self.remaining() {
332      return Err(LabelError::NotEnoughBytes);
333    }
334    let label = self.split_to(len);
335    Self::try_into(label).map(Some).map_err(Into::into)
336  }
337}
338
339impl<T: Buf + sealed::Splitable + TryInto<Label, Error = InvalidLabel>> LabelBufExt for T {}
340
341/// Label extension for [`BufMut`] types.
342pub trait LabelBufMutExt: BufMut {
343  /// Add label prefix to the buffer.
344  fn add_label_header(&mut self, label: &Label) {
345    if label.is_empty() {
346      return;
347    }
348    self.put_u8(Label::TAG);
349    self.put_u8(label.len() as u8);
350    self.put_slice(label.as_bytes());
351  }
352}
353
354impl<T: BufMut> LabelBufMutExt for T {}
355
356mod sealed {
357  use bytes::{Bytes, BytesMut};
358
359  pub trait Splitable {
360    fn split_to(&mut self, len: usize) -> Self;
361  }
362
363  impl Splitable for BytesMut {
364    fn split_to(&mut self, len: usize) -> Self {
365      self.split_to(len)
366    }
367  }
368
369  impl Splitable for Bytes {
370    fn split_to(&mut self, len: usize) -> Self {
371      self.split_to(len)
372    }
373  }
374}
375
#[cfg(test)]
mod tests {
  use std::hash::{Hash, Hasher};

  use super::*;

  #[test]
  fn test_try_from_string() {
    let label = Label::try_from("hello".to_string()).unwrap();
    assert_eq!(label, "hello");

    assert!(Label::try_from("a".repeat(256)).is_err());
  }

  #[test]
  fn test_max_size_boundary() {
    // Exactly `MAX_SIZE` bytes is accepted; one more is rejected and reports the real length.
    let max = Label::try_from("a".repeat(Label::MAX_SIZE)).unwrap();
    assert_eq!(max.len(), Label::MAX_SIZE);
    assert!(matches!(
      Label::try_from("a".repeat(Label::MAX_SIZE + 1)),
      Err(InvalidLabel::TooLarge(n)) if n == Label::MAX_SIZE + 1
    ));
    assert!(matches!(
      Label::try_from(vec![b'a'; Label::MAX_SIZE + 1]),
      Err(InvalidLabel::TooLarge(_))
    ));
    let too_long: &'static str = Box::leak("a".repeat(Label::MAX_SIZE + 1).into_boxed_str());
    assert!(Label::from_static(too_long).is_err());
  }

  #[test]
  fn test_try_from_bytes() {
    let label = Label::try_from(Bytes::from("hello")).unwrap();
    assert_eq!(label, *"hello");

    assert!(Label::try_from(Bytes::from("a".repeat(256).into_bytes())).is_err());
    assert!(Label::try_from(Bytes::from_static(&[255; 25])).is_err());
  }

  #[test]
  fn test_try_from_bytes_mut() {
    let label = Label::try_from(BytesMut::from("hello")).unwrap();
    assert_eq!(label, "hello".to_string());

    assert!(Label::try_from(BytesMut::from([255; 25].as_slice())).is_err());
    assert!(Label::try_from(BytesMut::from([0; 256].as_slice())).is_err());
  }

  #[test]
  fn test_try_from_bytes_ref() {
    let label = Label::try_from(&Bytes::from("hello")).unwrap();
    assert_eq!(label, &"hello".to_string());

    assert!(Label::try_from(&Bytes::from("a".repeat(256).into_bytes())).is_err());
    assert!(Label::try_from(&Bytes::from_static(&[255; 25])).is_err());
  }

  #[test]
  fn test_debug_and_hash() {
    let label = Label::from_static("hello").unwrap();
    assert_eq!(format!("{:?}", label), "hello");

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    label.hash(&mut hasher);
    let h1 = hasher.finish();
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    "hello".hash(&mut hasher);
    let h2 = hasher.finish();
    assert_eq!(h1, h2);
    assert_eq!(label.as_ref(), "hello");
  }

  #[test]
  fn test_encoded_overhead() {
    assert_eq!(Label::empty().encoded_overhead(), 0);
    // tag byte + length byte + 5 label bytes
    assert_eq!(Label::from_static("hello").unwrap().encoded_overhead(), 7);
  }

  #[test]
  fn test_label_header_round_trip() {
    let label = Label::from_static("hello").unwrap();
    let mut buf = BytesMut::new();
    buf.add_label_header(&label);
    buf.put_slice(b"payload");
    assert_eq!(buf.len(), label.encoded_overhead() + b"payload".len());

    let mut bytes = buf.freeze();
    let decoded = bytes.remove_label_header().unwrap();
    assert_eq!(decoded, Some(label.clone()));
    assert_eq!(bytes, Bytes::from_static(b"payload"));
  }

  #[test]
  fn test_remove_label_header_without_label() {
    let mut empty = Bytes::new();
    assert!(empty.remove_label_header().unwrap().is_none());

    // A buffer that does not start with the tag is left untouched.
    let mut plain = Bytes::from_static(b"payload");
    assert!(plain.remove_label_header().unwrap().is_none());
    assert_eq!(plain, Bytes::from_static(b"payload"));

    // An empty label writes no header at all.
    let mut buf = BytesMut::new();
    buf.add_label_header(&Label::empty());
    assert!(buf.is_empty());
  }

  #[test]
  fn test_remove_label_header_truncated_label() {
    // Declares 10 label bytes but only provides one.
    let mut truncated = Bytes::copy_from_slice(&[Label::TAG, 10, b'a']);
    assert!(matches!(
      truncated.remove_label_header(),
      Err(LabelError::NotEnoughBytes)
    ));
  }
}