use nodecraft::CheapClone;
use smol_str::SmolStr;
use super::{Data, DataRef, DecodeError, EncodeError};
#[derive(Debug, thiserror::Error)]
pub enum ParseLabelError {
#[error("the size of label must between [0-253] bytes, got {0}")]
TooLarge(usize),
#[error(transparent)]
Utf8(#[from] core::str::Utf8Error),
}
/// A UTF-8 label backed by a [`SmolStr`].
///
/// The `TryFrom` conversions visible in this file reject inputs longer than
/// [`Label::MAX_SIZE`] bytes. `Display` prints the inner string as-is.
// NOTE(review): the derived `Deserialize` (feature `serde`) does not appear to go through
// that length check — confirm whether oversized labels can be produced by deserialization.
#[derive(Clone, Default, derive_more::Display)]
#[display("{_0}")]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub struct Label(pub(crate) SmolStr);
// Empty-bodied impl that opts `Label` into `CheapClone`.
// NOTE(review): presumably justified by `SmolStr` clones being cheap — confirm.
impl CheapClone for Label {}
impl Label {
pub const MAX_SIZE: usize = u8::MAX as usize - 2;
pub const EMPTY: &Label = &Label(SmolStr::new_inline(""));
#[inline]
pub const fn empty() -> Label {
Label(SmolStr::new_inline(""))
}
#[inline]
pub fn encoded_overhead(&self) -> usize {
if self.is_empty() { 0 } else { 2 + self.len() }
}
#[inline]
pub fn from_static(s: &'static str) -> Result<Self, ParseLabelError> {
Self::try_from(s)
}
#[inline]
pub fn as_bytes(&self) -> &[u8] {
self.0.as_bytes()
}
#[inline]
pub fn as_str(&self) -> &str {
self.0.as_str()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
#[inline]
pub fn len(&self) -> usize {
self.0.len()
}
}
/// Lets a `Label` be passed wherever an `AsRef<str>` is accepted.
impl AsRef<str> for Label {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
/// Borrows the label as a `str`.
///
/// The `Eq`, `Ord` and `Hash` impls in this file all delegate to the string slice, which
/// keeps them in agreement with the borrowed `str`.
impl core::borrow::Borrow<str> for Label {
    fn borrow(&self) -> &str {
        self.0.as_str()
    }
}
/// Partial ordering that always defers to the total order defined by `Ord`.
impl core::cmp::PartialOrd for Label {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(core::cmp::Ord::cmp(self, other))
    }
}
/// Orders labels by comparing their string contents.
impl core::cmp::Ord for Label {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        let (lhs, rhs) = (self.as_str(), other.as_str());
        core::cmp::Ord::cmp(lhs, rhs)
    }
}
/// Two labels are equal when their string contents are equal.
impl core::cmp::PartialEq for Label {
    fn eq(&self, other: &Self) -> bool {
        let lhs: &str = self.as_str();
        let rhs: &str = other.as_str();
        lhs == rhs
    }
}
/// Compares the label's contents with a `str`.
impl core::cmp::PartialEq<str> for Label {
    fn eq(&self, other: &str) -> bool {
        let lhs: &str = self.as_str();
        lhs == other
    }
}
/// Compares the label's contents with a `&str`.
impl core::cmp::PartialEq<&str> for Label {
    fn eq(&self, other: &&str) -> bool {
        let lhs: &str = self.as_str();
        let rhs: &str = other;
        lhs == rhs
    }
}
impl core::cmp::PartialEq<String> for Label {
fn eq(&self, other: &String) -> bool {
self.as_str() == other
}
}
impl core::cmp::PartialEq<&String> for Label {
fn eq(&self, other: &&String) -> bool {
self.as_str() == *other
}
}
// Equality is plain string equality (see the `PartialEq` impl), so it is a full
// equivalence relation and `Eq` can be asserted.
impl core::cmp::Eq for Label {}
/// Hashes the label exactly as its string contents would hash.
impl core::hash::Hash for Label {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        core::hash::Hash::hash(self.as_str(), state)
    }
}
/// Parses a label from a string slice by forwarding to `TryFrom<&str>`.
impl core::str::FromStr for Label {
    type Err = ParseLabelError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::try_from(s)
    }
}
/// Builds a label from a string slice.
///
/// Fails with [`ParseLabelError::TooLarge`] when `s` is longer than [`Label::MAX_SIZE`] bytes.
impl TryFrom<&str> for Label {
    type Error = ParseLabelError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s.len() {
            size if size > Self::MAX_SIZE => Err(ParseLabelError::TooLarge(size)),
            _ => Ok(Self(SmolStr::new(s))),
        }
    }
}
/// Builds a label from a borrowed `String` by forwarding to `TryFrom<&str>`.
impl TryFrom<&String> for Label {
    type Error = ParseLabelError;

    fn try_from(s: &String) -> Result<Self, Self::Error> {
        Self::try_from(s.as_str())
    }
}
impl TryFrom<String> for Label {
type Error = ParseLabelError;
fn try_from(s: String) -> Result<Self, Self::Error> {
if s.len() > Self::MAX_SIZE {
return Err(ParseLabelError::TooLarge(s.len()));
}
Ok(Self(s.into()))
}
}
/// Builds a label from an owned byte vector.
///
/// # Errors
///
/// - [`ParseLabelError::TooLarge`] if `s` is longer than [`Label::MAX_SIZE`] bytes.
/// - [`ParseLabelError::Utf8`] if `s` is not valid UTF-8.
impl TryFrom<Vec<u8>> for Label {
    type Error = ParseLabelError;

    fn try_from(s: Vec<u8>) -> Result<Self, Self::Error> {
        // Delegate to the slice conversion so both byte-based constructors check the size
        // first (an O(1) test) and only then validate UTF-8, and report the same error
        // variant for the same input.
        Self::try_from(s.as_slice())
    }
}
/// Builds a label from a byte slice.
///
/// The size is checked first, so an oversized input yields
/// [`ParseLabelError::TooLarge`]; otherwise invalid UTF-8 yields [`ParseLabelError::Utf8`].
impl TryFrom<&[u8]> for Label {
    type Error = ParseLabelError;

    fn try_from(s: &[u8]) -> Result<Self, Self::Error> {
        let size = s.len();
        if size > Self::MAX_SIZE {
            return Err(ParseLabelError::TooLarge(size));
        }

        core::str::from_utf8(s)
            .map(|text| Self(SmolStr::new(text)))
            .map_err(ParseLabelError::Utf8)
    }
}
/// Debug output is the raw label text, without surrounding quotes or escaping.
impl core::fmt::Debug for Label {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Errors related to handling a [`Label`].
#[derive(Debug, thiserror::Error)]
pub enum LabelError {
/// A value could not be converted into a [`Label`]; see [`ParseLabelError`].
#[error(transparent)]
ParseLabelError(#[from] ParseLabelError),
/// There was not enough data to decode a label.
#[error("not enough data to decode label")]
BufferUnderflow,
/// The label that was received differs from the one that was expected.
#[error("label mismatch: expected {expected}, got {got}")]
LabelMismatch {
/// The label that was expected.
expected: Label,
/// The label that was actually received.
got: Label,
},
/// A second label header was seen although the inbound label check is disabled.
#[error(
"unexpected double label header, inbound label check is disabled, but got double label header: local={local}, remote={remote}"
)]
Duplicate {
/// The local label.
local: Label,
/// The remote label.
remote: Label,
},
}
impl LabelError {
pub fn mismatch(expected: Label, got: Label) -> Self {
Self::LabelMismatch { expected, got }
}
pub fn duplicate(local: Label, remote: Label) -> Self {
Self::Duplicate { local, remote }
}
}
/// Borrowed decoding of a label: the whole input buffer is interpreted as the label text.
impl<'a> DataRef<'a, Label> for &'a str {
    /// Decodes `buf` as a UTF-8 label and returns the number of bytes consumed (always
    /// `buf.len()`) together with the borrowed string.
    ///
    /// Fails with a custom [`DecodeError`] if `buf` is longer than [`Label::MAX_SIZE`]
    /// bytes or is not valid UTF-8.
    fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError>
    where
        Self: Sized,
    {
        match buf.len() {
            size if size > Label::MAX_SIZE => Err(DecodeError::custom(
                ParseLabelError::TooLarge(size).to_string(),
            )),
            size => core::str::from_utf8(buf)
                .map(|text| (size, text))
                .map_err(|e| DecodeError::custom(e.to_string())),
        }
    }
}
impl Data for Label {
type Ref<'a> = &'a str;
fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
where
Self: Sized,
{
Ok(Self(SmolStr::new(val)))
}
fn encoded_len(&self) -> usize {
self.len()
}
fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
let len = self.len();
if len > buf.len() {
return Err(EncodeError::insufficient_buffer(len, buf.len()));
}
buf[..len].copy_from_slice(self.as_bytes());
Ok(len)
}
}
#[cfg(test)]
mod tests {
    use core::hash::{Hash, Hasher};

    use super::*;

    #[test]
    fn test_try_from_string() {
        let label = Label::try_from("hello".to_string()).unwrap();
        assert_eq!(label, "hello");

        // Exactly at the limit is accepted ...
        let max = Label::try_from("a".repeat(Label::MAX_SIZE)).unwrap();
        assert_eq!(max.len(), Label::MAX_SIZE);

        // ... and one byte over is rejected with the offending length.
        assert!(matches!(
            Label::try_from("a".repeat(Label::MAX_SIZE + 1)),
            Err(ParseLabelError::TooLarge(n)) if n == Label::MAX_SIZE + 1
        ));
        assert!(Label::try_from("a".repeat(256)).is_err());
    }

    #[test]
    fn test_debug_and_hash() {
        let label = Label::from_static("hello").unwrap();
        assert_eq!(format!("{:?}", label), "hello");

        // A label must hash exactly like the string it wraps.
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        label.hash(&mut hasher);
        let h1 = hasher.finish();
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        "hello".hash(&mut hasher);
        let h2 = hasher.finish();
        assert_eq!(h1, h2);

        assert_eq!(label.as_ref(), "hello");
    }
}