memberlist_types/
label.rs1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use nodecraft::CheapClone;
3
4#[derive(Debug, thiserror::Error)]
6pub enum InvalidLabel {
7 #[error("the size of label must between [0-255] bytes, got {0}")]
9 TooLarge(usize),
10 #[error(transparent)]
12 Utf8(#[from] core::str::Utf8Error),
13}
14
15#[derive(Clone)]
24pub struct Label(Bytes);
25
26impl CheapClone for Label {}
27
28impl Label {
29 pub const MAX_SIZE: usize = u8::MAX as usize - 2;
31
32 pub const TAG: u8 = 127;
34
35 #[inline]
37 pub const fn empty() -> Label {
38 Label(Bytes::new())
39 }
40
41 #[inline]
43 pub fn encoded_overhead(&self) -> usize {
44 if self.is_empty() {
45 0
46 } else {
47 2 + self.len()
48 }
49 }
50
51 #[inline]
53 pub fn from_static(s: &'static str) -> Result<Self, InvalidLabel> {
54 Self::try_from(s)
55 }
56
57 #[inline]
59 pub fn as_bytes(&self) -> &[u8] {
60 &self.0
61 }
62
63 #[inline]
65 pub fn as_str(&self) -> &str {
66 core::str::from_utf8(&self.0).unwrap()
67 }
68
69 #[inline]
71 pub fn is_empty(&self) -> bool {
72 self.0.is_empty()
73 }
74
75 #[inline]
77 pub fn len(&self) -> usize {
78 self.0.len()
79 }
80}
81
#[cfg(feature = "serde")]
const _: () = {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Human-readable formats receive the label as a string; binary formats receive raw bytes.
    impl Serialize for Label {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            if !serializer.is_human_readable() {
                return serializer.serialize_bytes(self.as_bytes());
            }
            serializer.serialize_str(self.as_str())
        }
    }

    /// Mirrors `Serialize`: reads a string or a byte buffer depending on the format. The value
    /// is then validated through the same `TryFrom` conversions used by every other
    /// constructor.
    impl<'de> Deserialize<'de> for Label {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let converted = if deserializer.is_human_readable() {
                String::deserialize(deserializer).map(Label::try_from)?
            } else {
                Bytes::deserialize(deserializer).map(Label::try_from)?
            };
            converted.map_err(serde::de::Error::custom)
        }
    }
};
111
112impl AsRef<str> for Label {
113 fn as_ref(&self) -> &str {
114 self.as_str()
115 }
116}
117
118impl core::cmp::PartialOrd for Label {
119 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
120 Some(self.cmp(other))
121 }
122}
123
124impl core::cmp::Ord for Label {
125 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
126 self.as_str().cmp(other.as_str())
127 }
128}
129
130impl core::cmp::PartialEq for Label {
131 fn eq(&self, other: &Self) -> bool {
132 self.as_str() == other.as_str()
133 }
134}
135
136impl core::cmp::PartialEq<str> for Label {
137 fn eq(&self, other: &str) -> bool {
138 self.as_str() == other
139 }
140}
141
142impl core::cmp::PartialEq<&str> for Label {
143 fn eq(&self, other: &&str) -> bool {
144 self.as_str() == *other
145 }
146}
147
148impl core::cmp::PartialEq<String> for Label {
149 fn eq(&self, other: &String) -> bool {
150 self.as_str() == other
151 }
152}
153
154impl core::cmp::PartialEq<&String> for Label {
155 fn eq(&self, other: &&String) -> bool {
156 self.as_str() == *other
157 }
158}
159
160impl core::cmp::Eq for Label {}
161
162impl core::hash::Hash for Label {
163 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
164 self.as_str().hash(state)
165 }
166}
167
168impl TryFrom<&str> for Label {
169 type Error = InvalidLabel;
170
171 fn try_from(s: &str) -> Result<Self, Self::Error> {
172 if s.len() > Self::MAX_SIZE {
173 return Err(InvalidLabel::TooLarge(s.len()));
174 }
175 Ok(Self(Bytes::copy_from_slice(s.as_bytes())))
176 }
177}
178
179impl TryFrom<&String> for Label {
180 type Error = InvalidLabel;
181
182 fn try_from(s: &String) -> Result<Self, Self::Error> {
183 s.as_str().try_into()
184 }
185}
186
187impl TryFrom<String> for Label {
188 type Error = InvalidLabel;
189
190 fn try_from(s: String) -> Result<Self, Self::Error> {
191 if s.len() > Self::MAX_SIZE {
192 return Err(InvalidLabel::TooLarge(s.len()));
193 }
194 Ok(Self(s.into()))
195 }
196}
197
198impl TryFrom<Bytes> for Label {
199 type Error = InvalidLabel;
200
201 fn try_from(s: Bytes) -> Result<Self, Self::Error> {
202 if s.len() > Self::MAX_SIZE {
203 return Err(InvalidLabel::TooLarge(s.len()));
204 }
205 match core::str::from_utf8(s.as_ref()) {
206 Ok(_) => Ok(Self(s)),
207 Err(e) => Err(InvalidLabel::Utf8(e)),
208 }
209 }
210}
211
212impl TryFrom<Vec<u8>> for Label {
213 type Error = InvalidLabel;
214
215 fn try_from(s: Vec<u8>) -> Result<Self, Self::Error> {
216 Label::try_from(Bytes::from(s))
217 }
218}
219
220impl TryFrom<&[u8]> for Label {
221 type Error = InvalidLabel;
222
223 fn try_from(s: &[u8]) -> Result<Self, Self::Error> {
224 if s.len() > Self::MAX_SIZE {
225 return Err(InvalidLabel::TooLarge(s.len()));
226 }
227 match core::str::from_utf8(s) {
228 Ok(_) => Ok(Self(Bytes::copy_from_slice(s))),
229 Err(e) => Err(InvalidLabel::Utf8(e)),
230 }
231 }
232}
233
234impl TryFrom<&Bytes> for Label {
235 type Error = InvalidLabel;
236
237 fn try_from(s: &Bytes) -> Result<Self, Self::Error> {
238 if s.len() > Self::MAX_SIZE {
239 return Err(InvalidLabel::TooLarge(s.len()));
240 }
241 match core::str::from_utf8(s.as_ref()) {
242 Ok(_) => Ok(Self(s.clone())),
243 Err(e) => Err(InvalidLabel::Utf8(e)),
244 }
245 }
246}
247
248impl TryFrom<BytesMut> for Label {
249 type Error = InvalidLabel;
250
251 fn try_from(s: BytesMut) -> Result<Self, Self::Error> {
252 if s.len() > Self::MAX_SIZE {
253 return Err(InvalidLabel::TooLarge(s.len()));
254 }
255 match core::str::from_utf8(s.as_ref()) {
256 Ok(_) => Ok(Self(s.freeze())),
257 Err(e) => Err(InvalidLabel::Utf8(e)),
258 }
259 }
260}
261
262impl core::fmt::Debug for Label {
263 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
264 write!(f, "{}", self.as_str())
265 }
266}
267
268impl core::fmt::Display for Label {
269 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
270 write!(f, "{}", self.as_str())
271 }
272}
273
274#[derive(Debug, thiserror::Error)]
276pub enum LabelError {
277 #[error(transparent)]
279 InvalidLabel(#[from] InvalidLabel),
280 #[error("not enough bytes to decode label")]
282 NotEnoughBytes,
283 #[error("label mismatch: expected {expected}, got {got}")]
285 LabelMismatch {
286 expected: Label,
288 got: Label,
290 },
291
292 #[error("unexpected double label header, inbound label check is disabled, but got double label header: local={local}, remote={remote}")]
294 Duplicate {
295 local: Label,
297 remote: Label,
299 },
300}
301
302impl LabelError {
303 pub fn mismatch(expected: Label, got: Label) -> Self {
305 Self::LabelMismatch { expected, got }
306 }
307
308 pub fn duplicate(local: Label, remote: Label) -> Self {
310 Self::Duplicate { local, remote }
311 }
312}
313
314pub trait LabelBufExt: Buf + sealed::Splitable + TryInto<Label, Error = InvalidLabel> {
316 fn remove_label_header(&mut self) -> Result<Option<Label>, LabelError>
318 where
319 Self: Sized,
320 {
321 if self.remaining() < 1 {
322 return Ok(None);
323 }
324
325 let data = self.chunk();
326 if data[0] != Label::TAG {
327 return Ok(None);
328 }
329 self.advance(1);
330 let len = self.get_u8() as usize;
331 if len > self.remaining() {
332 return Err(LabelError::NotEnoughBytes);
333 }
334 let label = self.split_to(len);
335 Self::try_into(label).map(Some).map_err(Into::into)
336 }
337}
338
339impl<T: Buf + sealed::Splitable + TryInto<Label, Error = InvalidLabel>> LabelBufExt for T {}
340
341pub trait LabelBufMutExt: BufMut {
343 fn add_label_header(&mut self, label: &Label) {
345 if label.is_empty() {
346 return;
347 }
348 self.put_u8(Label::TAG);
349 self.put_u8(label.len() as u8);
350 self.put_slice(label.as_bytes());
351 }
352}
353
354impl<T: BufMut> LabelBufMutExt for T {}
355
356mod sealed {
357 use bytes::{Bytes, BytesMut};
358
359 pub trait Splitable {
360 fn split_to(&mut self, len: usize) -> Self;
361 }
362
363 impl Splitable for BytesMut {
364 fn split_to(&mut self, len: usize) -> Self {
365 self.split_to(len)
366 }
367 }
368
369 impl Splitable for Bytes {
370 fn split_to(&mut self, len: usize) -> Self {
371 self.split_to(len)
372 }
373 }
374}
375
#[cfg(test)]
mod tests {
    use std::hash::{Hash, Hasher};

    use super::*;

    #[test]
    fn test_try_from_string() {
        let label = Label::try_from("hello".to_string()).unwrap();
        assert_eq!(label, "hello");

        assert!(Label::try_from("a".repeat(256)).is_err());
    }

    #[test]
    fn test_try_from_bytes() {
        let label = Label::try_from(Bytes::from("hello")).unwrap();
        assert_eq!(label, *"hello");

        assert!(Label::try_from(Bytes::from("a".repeat(256).into_bytes())).is_err());
        assert!(Label::try_from(Bytes::from_static(&[255; 25])).is_err());
    }

    #[test]
    fn test_try_from_bytes_mut() {
        let label = Label::try_from(BytesMut::from("hello")).unwrap();
        assert_eq!(label, "hello".to_string());

        assert!(Label::try_from(BytesMut::from([255; 25].as_slice())).is_err());
        assert!(Label::try_from(BytesMut::from([0; 256].as_slice())).is_err());
    }

    #[test]
    fn test_try_from_bytes_ref() {
        let label = Label::try_from(&Bytes::from("hello")).unwrap();
        assert_eq!(label, &"hello".to_string());

        assert!(Label::try_from(&Bytes::from("a".repeat(256).into_bytes())).is_err());
        assert!(Label::try_from(&Bytes::from_static(&[255; 25])).is_err());
    }

    #[test]
    fn test_debug_and_hash() {
        let label = Label::from_static("hello").unwrap();
        assert_eq!(format!("{:?}", label), "hello");

        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        label.hash(&mut hasher);
        let h1 = hasher.finish();
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        "hello".hash(&mut hasher);
        let h2 = hasher.finish();
        assert_eq!(h1, h2);
        assert_eq!(label.as_ref(), "hello");
    }

    /// Pins the exact size limit rather than only a far-out-of-range length.
    #[test]
    fn test_max_size_boundary() {
        assert!(Label::try_from("a".repeat(Label::MAX_SIZE)).is_ok());
        assert!(matches!(
            Label::try_from("a".repeat(Label::MAX_SIZE + 1)),
            Err(InvalidLabel::TooLarge(n)) if n == Label::MAX_SIZE + 1
        ));
        assert!(Label::try_from(vec![b'a'; Label::MAX_SIZE]).is_ok());
        assert!(Label::try_from(vec![b'a'; Label::MAX_SIZE + 1]).is_err());
    }

    #[test]
    fn test_empty_label_writes_no_header() {
        let label = Label::empty();
        assert!(label.is_empty());
        assert_eq!(label.len(), 0);
        assert_eq!(label.encoded_overhead(), 0);

        let mut buf = BytesMut::new();
        buf.add_label_header(&label);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_label_header_roundtrip() {
        let label = Label::from_static("hello").unwrap();
        let mut buf = BytesMut::new();
        buf.add_label_header(&label);
        buf.put_slice(b"payload");
        assert_eq!(buf.len(), label.encoded_overhead() + b"payload".len());

        let mut buf = buf.freeze();
        assert_eq!(buf.remove_label_header().unwrap(), Some(label));
        assert_eq!(&buf[..], b"payload");
    }

    #[test]
    fn test_remove_label_header_without_tag() {
        let mut empty = Bytes::new();
        assert_eq!(empty.remove_label_header().unwrap(), None);

        let mut plain = Bytes::from_static(b"payload");
        assert_eq!(plain.remove_label_header().unwrap(), None);
        assert_eq!(&plain[..], b"payload");
    }

    #[test]
    fn test_remove_label_header_errors() {
        let mut short = Bytes::copy_from_slice(&[Label::TAG, 5, b'a']);
        assert!(matches!(
            short.remove_label_header(),
            Err(LabelError::NotEnoughBytes)
        ));

        let mut not_utf8 = BytesMut::from(&[Label::TAG, 1, 0xff][..]);
        assert!(matches!(
            not_utf8.remove_label_header(),
            Err(LabelError::InvalidLabel(InvalidLabel::Utf8(_)))
        ));
    }
}