Skip to main content

embeddenator_io/io/envelope.rs

1use std::io;
2
/// Four-byte signature that opens every enveloped payload.
const MAGIC: [u8; 4] = [b'E', b'D', b'N', b'1'];
/// Size of the fixed envelope header: magic, payload-kind byte, codec byte,
/// reserved `u16`, and the little-endian `u64` uncompressed length.
const HEADER_LEN: usize = MAGIC.len() + 1 + 1 + 2 + 8;
5
/// Tag byte identifying what an envelope's payload contains.
///
/// The explicit discriminants are the values written into the envelope header,
/// so they must not be renumbered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum PayloadKind {
    EngramBincode = 1,
    SubEngramBincode = 2,
}

impl PayloadKind {
    /// Decodes a tag byte, returning `None` for values that name no known kind.
    fn from_u8(v: u8) -> Option<Self> {
        let kind = match v {
            1 => Self::EngramBincode,
            2 => Self::SubEngramBincode,
            _ => return None,
        };
        Some(kind)
    }
}
22
/// Compression codec recorded in the envelope header.
///
/// The explicit discriminants are the values written into the header, so they must
/// not be renumbered. `None` means the payload is stored uncompressed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum CompressionCodec {
    None = 0,
    Zstd = 1,
    Lz4 = 2,
}

impl CompressionCodec {
    /// Decodes a codec byte, returning `None` for values that name no known codec.
    fn from_u8(v: u8) -> Option<Self> {
        let codec = match v {
            0 => Self::None,
            1 => Self::Zstd,
            2 => Self::Lz4,
            _ => return None,
        };
        Some(codec)
    }
}
41
42#[derive(Clone, Copy, Debug)]
43pub struct BinaryWriteOptions {
44    pub codec: CompressionCodec,
45    pub level: Option<i32>,
46}
47
48impl Default for BinaryWriteOptions {
49    fn default() -> Self {
50        Self {
51            codec: CompressionCodec::None,
52            level: None,
53        }
54    }
55}
56
57pub fn wrap_or_legacy(
58    kind: PayloadKind,
59    opts: BinaryWriteOptions,
60    raw: &[u8],
61) -> io::Result<Vec<u8>> {
62    if opts.codec == CompressionCodec::None {
63        return Ok(raw.to_vec());
64    }
65
66    let compressed = compress(opts.codec, raw, opts.level)?;
67
68    let mut out = Vec::with_capacity(HEADER_LEN + compressed.len());
69    out.extend_from_slice(&MAGIC);
70    out.push(kind as u8);
71    out.push(opts.codec as u8);
72    out.extend_from_slice(&0u16.to_le_bytes());
73    out.extend_from_slice(&(raw.len() as u64).to_le_bytes());
74    out.extend_from_slice(&compressed);
75
76    Ok(out)
77}
78
79pub fn unwrap_auto(expected_kind: PayloadKind, data: &[u8]) -> io::Result<Vec<u8>> {
80    if data.len() < HEADER_LEN || data[..4] != MAGIC {
81        return Ok(data.to_vec());
82    }
83
84    let kind = PayloadKind::from_u8(data[4])
85        .ok_or_else(|| io::Error::other("unknown envelope payload kind"))?;
86    if kind != expected_kind {
87        return Err(io::Error::other("unexpected envelope payload kind"));
88    }
89
90    let codec = CompressionCodec::from_u8(data[5])
91        .ok_or_else(|| io::Error::other("unknown envelope compression codec"))?;
92    let uncompressed_len =
93        u64::from_le_bytes(data[8..16].try_into().expect("slice length checked")) as usize;
94
95    let payload = &data[HEADER_LEN..];
96    let decoded = match codec {
97        CompressionCodec::None => payload.to_vec(),
98        CompressionCodec::Zstd | CompressionCodec::Lz4 => decompress(codec, payload)?,
99    };
100
101    if decoded.len() != uncompressed_len {
102        return Err(io::Error::other("envelope size mismatch"));
103    }
104
105    Ok(decoded)
106}
107
108fn compress(codec: CompressionCodec, raw: &[u8], level: Option<i32>) -> io::Result<Vec<u8>> {
109    match codec {
110        CompressionCodec::None => Ok(raw.to_vec()),
111        CompressionCodec::Zstd => compress_zstd(raw, level),
112        CompressionCodec::Lz4 => compress_lz4(raw),
113    }
114}
115
116fn decompress(codec: CompressionCodec, payload: &[u8]) -> io::Result<Vec<u8>> {
117    match codec {
118        CompressionCodec::None => Ok(payload.to_vec()),
119        CompressionCodec::Zstd => decompress_zstd(payload),
120        CompressionCodec::Lz4 => decompress_lz4(payload),
121    }
122}
123
/// Compresses `_raw` into a zstd stream.
///
/// `_level` is the zstd compression level; `None` is passed to the encoder as `0`.
/// The parameters carry a leading underscore so builds without the `compression-zstd`
/// feature do not warn about them being unused.
///
/// # Errors
///
/// Fails when the `compression-zstd` feature is disabled. Otherwise returns the
/// encoder's error wrapped via `io::Error::other`.
fn compress_zstd(_raw: &[u8], _level: Option<i32>) -> io::Result<Vec<u8>> {
    #[cfg(feature = "compression-zstd")]
    {
        let level = _level.unwrap_or(0);
        let reader = std::io::Cursor::new(_raw);
        zstd::stream::encode_all(reader, level).map_err(io::Error::other)
    }

    #[cfg(not(feature = "compression-zstd"))]
    {
        let msg = "zstd compression support not enabled (enable feature `compression-zstd`)";
        Err(io::Error::other(msg))
    }
}
139
/// Decodes a zstd stream back into the original bytes.
///
/// The parameter carries a leading underscore so builds without the `compression-zstd`
/// feature do not warn about it being unused.
///
/// # Errors
///
/// Fails when the `compression-zstd` feature is disabled. Otherwise returns the
/// decoder's error wrapped via `io::Error::other`.
fn decompress_zstd(_payload: &[u8]) -> io::Result<Vec<u8>> {
    #[cfg(feature = "compression-zstd")]
    {
        let reader = std::io::Cursor::new(_payload);
        zstd::stream::decode_all(reader).map_err(io::Error::other)
    }

    #[cfg(not(feature = "compression-zstd"))]
    {
        let msg = "zstd decompression support not enabled (enable feature `compression-zstd`)";
        Err(io::Error::other(msg))
    }
}
154
/// Compresses `_raw` with LZ4, prepending the uncompressed size to the output.
///
/// The parameter carries a leading underscore so builds without the `compression-lz4`
/// feature do not warn about it being unused.
///
/// # Errors
///
/// Fails only when the `compression-lz4` feature is disabled.
fn compress_lz4(_raw: &[u8]) -> io::Result<Vec<u8>> {
    #[cfg(feature = "compression-lz4")]
    {
        let framed = lz4_flex::compress_prepend_size(_raw);
        Ok(framed)
    }

    #[cfg(not(feature = "compression-lz4"))]
    {
        let msg = "lz4 compression support not enabled (enable feature `compression-lz4`)";
        Err(io::Error::other(msg))
    }
}
168
/// Decodes an LZ4 payload whose uncompressed size was prepended by the compressor.
///
/// The parameter carries a leading underscore so builds without the `compression-lz4`
/// feature do not warn about it being unused.
///
/// # Errors
///
/// Fails when the `compression-lz4` feature is disabled. Otherwise returns the
/// decoder's error wrapped via `io::Error::other`.
fn decompress_lz4(_payload: &[u8]) -> io::Result<Vec<u8>> {
    #[cfg(feature = "compression-lz4")]
    {
        let decoded = lz4_flex::decompress_size_prepended(_payload);
        decoded.map_err(io::Error::other)
    }

    #[cfg(not(feature = "compression-lz4"))]
    {
        let msg = "lz4 decompression support not enabled (enable feature `compression-lz4`)";
        Err(io::Error::other(msg))
    }
}