memberlist_quic/
compressor.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
use super::*;

/// Inclusive range of byte values (86 through 126) used for compressor tags;
/// `Compressor::Lzw` takes the first value of this range as its discriminant.
pub(super) const COMPRESS_TAG: core::ops::RangeInclusive<u8> =
  core::ops::RangeInclusive::new(86, 126);

impl<A, S, W> From<UnknownCompressor> for super::QuicTransportError<A, S, W>
where
  A: AddressResolver,
  S: StreamLayer,
  W: Wire,
{
  fn from(err: UnknownCompressor) -> Self {
    Self::Compressor(err.into())
  }
}

impl<A, S, W> From<CompressError> for super::QuicTransportError<A, S, W>
where
  A: AddressResolver,
  S: StreamLayer,
  W: Wire,
{
  fn from(err: CompressError) -> Self {
    Self::Compressor(err.into())
  }
}

impl<A, S, W> From<DecompressError> for super::QuicTransportError<A, S, W>
where
  A: AddressResolver,
  S: StreamLayer,
  W: Wire,
{
  fn from(err: DecompressError) -> Self {
    Self::Compressor(err.into())
  }
}

/// Compress/Decompress errors.
///
/// Umbrella error for this module: it wraps compression failures,
/// decompression failures and unrecognised compressor tags, and each of those
/// converts into it via `From` (`#[from]`).
#[derive(Debug, thiserror::Error)]
pub enum CompressorError {
  /// Compression failed; display and source are forwarded to the inner
  /// [`CompressError`] (`#[error(transparent)]`).
  #[error(transparent)]
  Compress(#[from] CompressError),
  /// Decompression failed; display and source are forwarded to the inner
  /// [`DecompressError`] (`#[error(transparent)]`).
  #[error(transparent)]
  Decompress(#[from] DecompressError),
  /// A tag byte did not match any known [`Compressor`] (see `TryFrom<u8>`).
  #[error(transparent)]
  UnknownCompressor(#[from] UnknownCompressor),
  /// The input was too short to be decompressed.
  #[error("compressor: not enough bytes to decompress")]
  NotEnoughBytes,
}

/// Compression algorithm used to compress/decompress bytes sent over the network.
///
/// The enum is `#[repr(u8)]`: each variant's discriminant is a byte taken from
/// `COMPRESS_TAG`, and `TryFrom<u8>` maps such a byte back to its variant.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
#[non_exhaustive]
pub enum Compressor {
  /// LZW decoder and encoder (LSB-first bit order, `LZW_LIT_WIDTH`-bit
  /// literals). This is the default compressor.
  #[default]
  // Discriminant is the first value of `COMPRESS_TAG` (86).
  Lzw = { *COMPRESS_TAG.start() },
}

impl Compressor {
  /// The size of the compressor in bytes.
  ///
  /// The enum is `#[repr(u8)]`, so this is `1`.
  pub const SIZE: usize = core::mem::size_of::<Self>();

  /// Returns `true` if the compressor is LZW.
  ///
  /// This is a `const fn`, so it can also be evaluated in constant contexts.
  #[inline]
  pub const fn is_lzw(&self) -> bool {
    matches!(self, Self::Lzw)
  }
}

/// Error returned when a byte does not identify any known [`Compressor`].
///
/// It carries the unrecognised tag byte, readable through
/// [`UnknownCompressor::value`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct UnknownCompressor(u8);

impl UnknownCompressor {
  /// Returns the tag byte that failed to match a known compressor.
  #[inline]
  pub const fn value(&self) -> u8 {
    self.0
  }
}

impl core::fmt::Display for UnknownCompressor {
  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
    write!(f, "unknown compressor {}", self.0)
  }
}

impl std::error::Error for UnknownCompressor {}

impl TryFrom<u8> for Compressor {
  type Error = UnknownCompressor;

  /// Parses a tag byte into a [`Compressor`].
  ///
  /// Returns [`UnknownCompressor`] carrying the original byte when it matches
  /// no variant's discriminant.
  fn try_from(value: u8) -> Result<Self, Self::Error> {
    // Bound to a const so it can be used directly as a match pattern.
    const LZW: u8 = Compressor::Lzw as u8;

    match value {
      LZW => Ok(Self::Lzw),
      unknown => Err(UnknownCompressor(unknown)),
    }
  }
}

/// Literal width in bits handed to the LZW encoder and decoder: one full byte.
// `u8::BITS` is 8, which always fits in a `u8`, so the cast cannot truncate.
const LZW_LIT_WIDTH: u8 = u8::BITS as u8;

/// Compress errors.
///
/// Returned by `Compressor::compress_into_bytes`.
#[derive(Debug, thiserror::Error)]
pub enum CompressError {
  /// The LZW encoder reported an error; display and source are forwarded to
  /// the wrapped `weezl::LzwError` (`#[error(transparent)]`).
  #[error(transparent)]
  Lzw(#[from] weezl::LzwError),
}

/// Decompress errors.
///
/// Returned by `Compressor::decompress`.
#[derive(Debug, thiserror::Error)]
pub enum DecompressError {
  /// The LZW decoder rejected the input; display and source are forwarded to
  /// the wrapped `weezl::LzwError` (`#[error(transparent)]`).
  #[error(transparent)]
  Lzw(#[from] weezl::LzwError),
}

impl Compressor {
  /// Decompresses the given buffer.
  pub fn decompress(&self, src: &[u8]) -> Result<Vec<u8>, DecompressError> {
    match self {
      Self::Lzw => weezl::decode::Decoder::new(weezl::BitOrder::Lsb, LZW_LIT_WIDTH)
        .decode(src)
        .map_err(DecompressError::Lzw),
    }
  }

  /// Compresses the given buffer.
  pub fn compress_into_bytes(&self, src: &[u8]) -> Result<Bytes, CompressError> {
    let mut buf = Vec::with_capacity(src.len());
    match self {
      Self::Lzw => weezl::encode::Encoder::new(weezl::BitOrder::Lsb, LZW_LIT_WIDTH)
        .into_vec(&mut buf)
        .encode_all(src)
        .status
        .map(|_| buf.into())
        .map_err(CompressError::Lzw),
    }
  }
}