memberlist_proto/
label.rs

1use nodecraft::CheapClone;
2use smol_str::SmolStr;
3
4use super::{Data, DataRef, DecodeError, EncodeError};
5
6/// Parse label error.
7#[derive(Debug, thiserror::Error)]
8pub enum ParseLabelError {
9  /// The label is too large.
10  #[error("the size of label must between [0-253] bytes, got {0}")]
11  TooLarge(usize),
12  /// The label is not valid utf8.
13  #[error(transparent)]
14  Utf8(#[from] core::str::Utf8Error),
15}
16
17/// General approach is to prefix all packets and streams with the same structure:
18///
19/// Encode:
20/// ```text
21///   magic type byte (127): u8
22///   length of label name:  u8 (because labels can't be longer than 253 bytes)
23///   label name:            bytes (max 253 bytes)
24/// ```
25#[derive(Clone, Default, derive_more::Display)]
26#[display("{_0}")]
27#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
28pub struct Label(pub(crate) SmolStr);
29
30impl CheapClone for Label {}
31
32impl Label {
33  /// The maximum size of a name in bytes.
34  pub const MAX_SIZE: usize = u8::MAX as usize - 2;
35
36  /// An empty label.
37  pub const EMPTY: &Label = &Label(SmolStr::new_inline(""));
38
39  /// Create an empty label.
40  #[inline]
41  pub const fn empty() -> Label {
42    Label(SmolStr::new_inline(""))
43  }
44
45  /// The encoded overhead of a label.
46  #[inline]
47  pub fn encoded_overhead(&self) -> usize {
48    if self.is_empty() { 0 } else { 2 + self.len() }
49  }
50
51  /// Create a label from a static str.
52  #[inline]
53  pub fn from_static(s: &'static str) -> Result<Self, ParseLabelError> {
54    Self::try_from(s)
55  }
56
57  /// Returns the label as a byte slice.
58  #[inline]
59  pub fn as_bytes(&self) -> &[u8] {
60    self.0.as_bytes()
61  }
62
63  /// Returns the str of the label.
64  #[inline]
65  pub fn as_str(&self) -> &str {
66    self.0.as_str()
67  }
68
69  /// Returns true if the label is empty.
70  #[inline]
71  pub fn is_empty(&self) -> bool {
72    self.0.is_empty()
73  }
74
75  /// Returns the length of the label in bytes.
76  #[inline]
77  pub fn len(&self) -> usize {
78    self.0.len()
79  }
80}
81
82impl AsRef<str> for Label {
83  fn as_ref(&self) -> &str {
84    self.as_str()
85  }
86}
87
88impl core::borrow::Borrow<str> for Label {
89  fn borrow(&self) -> &str {
90    self.as_str()
91  }
92}
93
94impl core::cmp::PartialOrd for Label {
95  fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
96    Some(self.cmp(other))
97  }
98}
99
100impl core::cmp::Ord for Label {
101  fn cmp(&self, other: &Self) -> core::cmp::Ordering {
102    self.as_str().cmp(other.as_str())
103  }
104}
105
106impl core::cmp::PartialEq for Label {
107  fn eq(&self, other: &Self) -> bool {
108    self.as_str() == other.as_str()
109  }
110}
111
112impl core::cmp::PartialEq<str> for Label {
113  fn eq(&self, other: &str) -> bool {
114    self.as_str() == other
115  }
116}
117
118impl core::cmp::PartialEq<&str> for Label {
119  fn eq(&self, other: &&str) -> bool {
120    self.as_str() == *other
121  }
122}
123
124impl core::cmp::PartialEq<String> for Label {
125  fn eq(&self, other: &String) -> bool {
126    self.as_str() == other
127  }
128}
129
130impl core::cmp::PartialEq<&String> for Label {
131  fn eq(&self, other: &&String) -> bool {
132    self.as_str() == *other
133  }
134}
135
136impl core::cmp::Eq for Label {}
137
138impl core::hash::Hash for Label {
139  fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
140    self.as_str().hash(state)
141  }
142}
143
144impl core::str::FromStr for Label {
145  type Err = ParseLabelError;
146
147  fn from_str(s: &str) -> Result<Self, Self::Err> {
148    s.try_into()
149  }
150}
151
152impl TryFrom<&str> for Label {
153  type Error = ParseLabelError;
154
155  fn try_from(s: &str) -> Result<Self, Self::Error> {
156    if s.len() > Self::MAX_SIZE {
157      return Err(ParseLabelError::TooLarge(s.len()));
158    }
159    Ok(Self(SmolStr::new(s)))
160  }
161}
162
163impl TryFrom<&String> for Label {
164  type Error = ParseLabelError;
165
166  fn try_from(s: &String) -> Result<Self, Self::Error> {
167    s.as_str().try_into()
168  }
169}
170
171impl TryFrom<String> for Label {
172  type Error = ParseLabelError;
173
174  fn try_from(s: String) -> Result<Self, Self::Error> {
175    if s.len() > Self::MAX_SIZE {
176      return Err(ParseLabelError::TooLarge(s.len()));
177    }
178    Ok(Self(s.into()))
179  }
180}
181
182impl TryFrom<Vec<u8>> for Label {
183  type Error = ParseLabelError;
184
185  fn try_from(s: Vec<u8>) -> Result<Self, Self::Error> {
186    String::from_utf8(s)
187      .map_err(|e| e.utf8_error().into())
188      .and_then(Self::try_from)
189  }
190}
191
192impl TryFrom<&[u8]> for Label {
193  type Error = ParseLabelError;
194
195  fn try_from(s: &[u8]) -> Result<Self, Self::Error> {
196    if s.len() > Self::MAX_SIZE {
197      return Err(ParseLabelError::TooLarge(s.len()));
198    }
199    match core::str::from_utf8(s) {
200      Ok(s) => Ok(Self(SmolStr::new(s))),
201      Err(e) => Err(ParseLabelError::Utf8(e)),
202    }
203  }
204}
205
206impl core::fmt::Debug for Label {
207  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
208    write!(f, "{}", self.as_str())
209  }
210}
211
212/// Label error.
213#[derive(Debug, thiserror::Error)]
214pub enum LabelError {
215  /// Invalid label.
216  #[error(transparent)]
217  ParseLabelError(#[from] ParseLabelError),
218  /// Not enough data to decode label.
219  #[error("not enough data to decode label")]
220  BufferUnderflow,
221  /// Label mismatch.
222  #[error("label mismatch: expected {expected}, got {got}")]
223  LabelMismatch {
224    /// Expected label.
225    expected: Label,
226    /// Got label.
227    got: Label,
228  },
229
230  /// Unexpected double label header
231  #[error(
232    "unexpected double label header, inbound label check is disabled, but got double label header: local={local}, remote={remote}"
233  )]
234  Duplicate {
235    /// The local label.
236    local: Label,
237    /// The remote label.
238    remote: Label,
239  },
240}
241
242impl LabelError {
243  /// Creates a new `LabelError::LabelMismatch`.
244  pub fn mismatch(expected: Label, got: Label) -> Self {
245    Self::LabelMismatch { expected, got }
246  }
247
248  /// Creates a new `LabelError::Duplicate`.
249  pub fn duplicate(local: Label, remote: Label) -> Self {
250    Self::Duplicate { local, remote }
251  }
252}
253
254impl<'a> DataRef<'a, Label> for &'a str {
255  fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError>
256  where
257    Self: Sized,
258  {
259    let len = buf.len();
260    if len > Label::MAX_SIZE {
261      return Err(DecodeError::custom(
262        ParseLabelError::TooLarge(len).to_string(),
263      ));
264    }
265
266    Ok((
267      len,
268      core::str::from_utf8(buf).map_err(|e| DecodeError::custom(e.to_string()))?,
269    ))
270  }
271}
272
273impl Data for Label {
274  type Ref<'a> = &'a str;
275
276  fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
277  where
278    Self: Sized,
279  {
280    Ok(Self(SmolStr::new(val)))
281  }
282
283  fn encoded_len(&self) -> usize {
284    self.len()
285  }
286
287  fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
288    let len = self.len();
289    if len > buf.len() {
290      return Err(EncodeError::insufficient_buffer(len, buf.len()));
291    }
292    buf[..len].copy_from_slice(self.as_bytes());
293    Ok(len)
294  }
295}
296
#[cfg(test)]
mod tests {
  use core::hash::{Hash, Hasher};

  use super::*;

  /// `TryFrom<String>` accepts labels up to `MAX_SIZE` bytes and rejects longer ones.
  #[test]
  fn test_try_from_string() {
    let label = Label::try_from("hello".to_string()).unwrap();
    assert_eq!(label, "hello");

    // The empty string is a valid (empty) label.
    assert!(Label::try_from(String::new()).unwrap().is_empty());

    // Exactly at the limit is accepted; one byte over is rejected and the
    // error reports the offending length.
    assert!(Label::try_from("a".repeat(Label::MAX_SIZE)).is_ok());
    assert!(matches!(
      Label::try_from("a".repeat(Label::MAX_SIZE + 1)),
      Err(ParseLabelError::TooLarge(n)) if n == Label::MAX_SIZE + 1
    ));

    assert!(Label::try_from("a".repeat(256)).is_err());
  }

  /// `Debug` prints the bare text and `Hash` matches the hash of the equivalent `str`.
  #[test]
  fn test_debug_and_hash() {
    let label = Label::from_static("hello").unwrap();
    assert_eq!(format!("{:?}", label), "hello");

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    label.hash(&mut hasher);
    let h1 = hasher.finish();
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    "hello".hash(&mut hasher);
    let h2 = hasher.finish();
    assert_eq!(h1, h2);
    assert_eq!(label.as_ref(), "hello");
  }
}