1use crate::coding::{Decode, Encode};
6use byteorder::{ReadBytesExt, WriteBytesExt};
7use std::io::{Read, Write};
8
9#[cfg(feature = "zstd")]
10use std::sync::Arc;
11
/// A zstd dictionary: the raw dictionary bytes paired with a 32-bit identifier.
///
/// The bytes live behind an [`Arc`], so cloning a `ZstdDictionary` bumps a
/// reference count instead of copying the dictionary contents.
#[cfg(feature = "zstd")]
#[derive(Clone)]
pub struct ZstdDictionary {
    /// Identifier derived from `raw` when the dictionary is constructed.
    id: u32,

    /// The dictionary bytes, shared between clones.
    raw: Arc<[u8]>,
}
37
#[cfg(feature = "zstd")]
impl ZstdDictionary {
    /// Builds a dictionary from `raw`, copying the bytes into shared storage.
    ///
    /// The identifier is computed from the bytes via `compute_dict_id`, so two
    /// dictionaries built from identical bytes report the same id.
    #[must_use]
    pub fn new(raw: &[u8]) -> Self {
        let id = compute_dict_id(raw);
        let raw: Arc<[u8]> = raw.into();
        Self { id, raw }
    }

    /// Returns the 32-bit identifier computed for this dictionary.
    #[must_use]
    pub fn id(&self) -> u32 {
        let Self { id, .. } = self;
        *id
    }

    /// Returns the dictionary bytes exactly as they were passed to [`Self::new`].
    #[must_use]
    pub fn raw(&self) -> &[u8] {
        &self.raw[..]
    }
}
65
/// Compact debug output: prints the id as zero-padded hex and the byte count,
/// never the dictionary contents themselves.
#[cfg(feature = "zstd")]
impl std::fmt::Debug for ZstdDictionary {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let size = self.raw.len();

        let mut out = f.debug_struct("ZstdDictionary");
        // `{:#010x}` renders the id as `0x` followed by eight hex digits.
        out.field("id", &format_args!("{:#010x}", self.id));
        out.field("size", &size);
        out.finish()
    }
}
75
/// Derives the 32-bit dictionary identifier from the dictionary bytes.
///
/// Hashes `raw` with 64-bit XXH3 and keeps only the low 32 bits of the digest.
#[cfg(feature = "zstd")]
#[expect(
    clippy::cast_possible_truncation,
    reason = "intentionally truncated to 32-bit fingerprint"
)]
fn compute_dict_id(raw: &[u8]) -> u32 {
    let digest: u64 = xxhash_rust::xxh3::xxh3_64(raw);
    digest as u32
}
85
/// Selects the compression algorithm and its parameters.
///
/// The set of variants depends on the enabled crate features (`lz4`, `zstd`).
/// The enum is `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionType {
    /// No compression.
    None,

    /// LZ4 compression.
    #[cfg(feature = "lz4")]
    Lz4,

    /// Zstd compression at the contained level.
    ///
    /// Prefer [`CompressionType::zstd`], which rejects levels outside `1..=22`;
    /// building the variant directly skips that check.
    #[cfg(feature = "zstd")]
    Zstd(i32),

    /// Zstd compression with a dictionary.
    ///
    /// Prefer [`CompressionType::zstd_dict`], which rejects levels outside
    /// `1..=22`; building the variant directly skips that check.
    #[cfg(feature = "zstd")]
    ZstdDict {
        /// Zstd compression level.
        level: i32,

        /// Identifier of the dictionary; presumably the value reported by
        /// `ZstdDictionary::id` (nothing in this type enforces that).
        dict_id: u32,
    },
}
139
140impl CompressionType {
141 #[cfg(feature = "zstd")]
145 fn validate_zstd_level(level: i32) -> crate::Result<()> {
146 if !(1..=22).contains(&level) {
147 return Err(crate::Error::Io(std::io::Error::other(format!(
150 "invalid zstd compression level {level}, expected 1..=22"
151 ))));
152 }
153 Ok(())
154 }
155
156 #[cfg(feature = "zstd")]
165 pub fn zstd(level: i32) -> crate::Result<Self> {
166 Self::validate_zstd_level(level)?;
167 Ok(Self::Zstd(level))
168 }
169
170 #[cfg(feature = "zstd")]
180 pub fn zstd_dict(level: i32, dict_id: u32) -> crate::Result<Self> {
181 Self::validate_zstd_level(level)?;
182 Ok(Self::ZstdDict { level, dict_id })
183 }
184}
185
186impl std::fmt::Display for CompressionType {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 write!(
189 f,
190 "{}",
191 match self {
192 Self::None => "none",
193
194 #[cfg(feature = "lz4")]
195 Self::Lz4 => "lz4",
196
197 #[cfg(feature = "zstd")]
198 Self::Zstd(_) => "zstd",
199
200 #[cfg(feature = "zstd")]
201 Self::ZstdDict { .. } => "zstd+dict",
202 }
203 )
204 }
205}
206
207impl Encode for CompressionType {
208 fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), crate::Error> {
209 match self {
210 Self::None => {
211 writer.write_u8(0)?;
212 }
213
214 #[cfg(feature = "lz4")]
215 Self::Lz4 => {
216 writer.write_u8(1)?;
217 }
218
219 #[cfg(feature = "zstd")]
220 Self::Zstd(level) => {
221 writer.write_u8(3)?;
222 debug_assert!(
225 (1..=22).contains(level),
226 "zstd level {level} outside valid range 1..=22"
227 );
228 #[expect(
229 clippy::cast_possible_truncation,
230 reason = "level range 1..=22 fits i8"
231 )]
232 writer.write_i8(*level as i8)?;
233 }
234
235 #[cfg(feature = "zstd")]
236 Self::ZstdDict { level, dict_id } => {
237 writer.write_u8(4)?;
238 debug_assert!(
239 (1..=22).contains(level),
240 "zstd level {level} outside valid range 1..=22"
241 );
242 #[expect(
243 clippy::cast_possible_truncation,
244 reason = "level range 1..=22 fits i8"
245 )]
246 writer.write_i8(*level as i8)?;
247 byteorder::WriteBytesExt::write_u32::<byteorder::LittleEndian>(writer, *dict_id)?;
248 }
249 }
250
251 Ok(())
252 }
253}
254
255impl Decode for CompressionType {
256 fn decode_from<R: Read>(reader: &mut R) -> Result<Self, crate::Error> {
257 let tag = reader.read_u8()?;
258
259 match tag {
260 0 => Ok(Self::None),
261
262 #[cfg(feature = "lz4")]
263 1 => Ok(Self::Lz4),
264
265 #[cfg(feature = "zstd")]
266 3 => {
267 let level = i32::from(reader.read_i8()?);
268 Self::validate_zstd_level(level)?;
270 Ok(Self::Zstd(level))
271 }
272
273 #[cfg(feature = "zstd")]
274 4 => {
275 let level = i32::from(reader.read_i8()?);
276 Self::validate_zstd_level(level)?;
277 let dict_id = byteorder::ReadBytesExt::read_u32::<byteorder::LittleEndian>(reader)?;
278 Ok(Self::ZstdDict { level, dict_id })
279 }
280
281 tag => Err(crate::Error::InvalidTag(("CompressionType", tag))),
282 }
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    /// Encodes `value`, decodes the produced bytes, and returns the result.
    fn roundtrip(value: CompressionType) -> CompressionType {
        let bytes = value.encode_into_vec();
        CompressionType::decode_from(&mut &bytes[..]).expect("decode failed")
    }

    #[test]
    fn compression_serialize_none() {
        let serialized = CompressionType::None.encode_into_vec();
        // Pin the exact wire bytes, not just the length.
        assert_eq!(serialized, [0_u8]);
    }

    #[test]
    fn compression_roundtrip_none() {
        assert_eq!(CompressionType::None, roundtrip(CompressionType::None));
    }

    #[test]
    fn compression_display_none() {
        assert_eq!(format!("{}", CompressionType::None), "none");
    }

    #[test]
    fn compression_decode_rejects_unknown_tag() {
        // Neither tag is assigned to a variant in any feature configuration.
        for tag in [2_u8, 0xFF] {
            let bytes = [tag];
            let result = CompressionType::decode_from(&mut &bytes[..]);
            assert!(result.is_err(), "tag {tag} should be rejected");
        }
    }

    #[test]
    fn compression_decode_rejects_empty_input() {
        let mut empty: &[u8] = &[];
        let result = CompressionType::decode_from(&mut empty);
        assert!(result.is_err(), "missing tag byte should be rejected");
    }

    #[cfg(feature = "lz4")]
    mod lz4 {
        use super::*;
        use test_log::test;

        #[test]
        fn compression_serialize_lz4() {
            let serialized = CompressionType::Lz4.encode_into_vec();
            assert_eq!(serialized, [1_u8]);
        }

        #[test]
        fn compression_roundtrip_lz4() {
            assert_eq!(CompressionType::Lz4, roundtrip(CompressionType::Lz4));
        }

        #[test]
        fn compression_display_lz4() {
            assert_eq!(format!("{}", CompressionType::Lz4), "lz4");
        }
    }

    #[cfg(feature = "zstd")]
    mod zstd {
        use super::*;
        use test_log::test;

        #[test]
        fn compression_serialize_zstd() {
            let serialized = CompressionType::Zstd(3).encode_into_vec();
            // Tag byte followed by the level.
            assert_eq!(serialized, [3_u8, 3]);
        }

        #[test]
        fn compression_roundtrip_zstd() {
            // Includes both ends of the accepted range.
            for level in [1, 3, 9, 19, 22] {
                let original = CompressionType::Zstd(level);
                assert_eq!(original, roundtrip(original));
            }
        }

        #[test]
        fn compression_display_zstd() {
            assert_eq!(format!("{}", CompressionType::Zstd(3)), "zstd");
        }

        #[test]
        fn compression_zstd_accepts_boundary_levels() {
            for level in [1, 22] {
                let built =
                    CompressionType::zstd(level).expect("boundary level should be accepted");
                assert_eq!(built, CompressionType::Zstd(level));
            }
        }

        #[test]
        fn compression_zstd_rejects_invalid_level() {
            for invalid_level in [0, 23, -1, 200] {
                let result = CompressionType::zstd(invalid_level);
                assert!(result.is_err(), "level {invalid_level} should be rejected");
            }
        }

        #[test]
        fn compression_zstd_decode_rejects_invalid_level() {
            let valid = CompressionType::Zstd(3).encode_into_vec();
            assert_eq!(valid.len(), 2);

            let corrupted = vec![valid[0], 0];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level 0 should be rejected on decode");

            let corrupted = vec![valid[0], 23];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level 23 should be rejected on decode");

            // 0xFF is read back as the signed level -1.
            let corrupted = vec![valid[0], 0xFF];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level -1 should be rejected on decode");
        }

        #[test]
        fn compression_zstd_decode_rejects_missing_level() {
            let truncated = [3_u8];
            let result = CompressionType::decode_from(&mut &truncated[..]);
            assert!(result.is_err(), "missing level byte should be rejected");
        }

        #[test]
        fn compression_serialize_zstd_dict() {
            let serialized = CompressionType::ZstdDict {
                level: 3,
                dict_id: 0xDEAD_BEEF,
            }
            .encode_into_vec();
            // Tag, level, then the dictionary id in little-endian order.
            assert_eq!(serialized, [4, 3, 0xEF, 0xBE, 0xAD, 0xDE]);
        }

        #[test]
        fn compression_roundtrip_zstd_dict() {
            for level in [1, 3, 9, 19, 22] {
                for dict_id in [0, 1, 0xDEAD_BEEF, u32::MAX] {
                    let original = CompressionType::ZstdDict { level, dict_id };
                    assert_eq!(original, roundtrip(original));
                }
            }
        }

        #[test]
        fn compression_display_zstd_dict() {
            assert_eq!(
                format!(
                    "{}",
                    CompressionType::ZstdDict {
                        level: 3,
                        dict_id: 42
                    }
                ),
                "zstd+dict"
            );
        }

        #[test]
        fn compression_zstd_dict_accepts_boundary_levels() {
            for level in [1, 22] {
                let built = CompressionType::zstd_dict(level, 42)
                    .expect("boundary level should be accepted");
                assert_eq!(built, CompressionType::ZstdDict { level, dict_id: 42 });
            }
        }

        #[test]
        fn compression_zstd_dict_rejects_invalid_level() {
            for invalid_level in [0, 23, -1, 200] {
                let result = CompressionType::zstd_dict(invalid_level, 42);
                assert!(result.is_err(), "level {invalid_level} should be rejected");
            }
        }

        #[test]
        fn compression_zstd_dict_decode_rejects_invalid_level() {
            let mut buf = CompressionType::ZstdDict {
                level: 3,
                dict_id: 42,
            }
            .encode_into_vec();
            assert_eq!(buf[0], 4);

            buf[1] = 0;
            let result = CompressionType::decode_from(&mut &buf[..]);
            assert!(result.is_err(), "level 0 should be rejected on decode");

            buf[1] = 23;
            let result = CompressionType::decode_from(&mut &buf[..]);
            assert!(result.is_err(), "level 23 should be rejected on decode");
        }

        #[test]
        fn compression_zstd_dict_decode_rejects_truncated_dict_id() {
            // Valid tag and level, but only one of the four dictionary-id bytes.
            let truncated = [4_u8, 3, 0xEF];
            let result = CompressionType::decode_from(&mut &truncated[..]);
            assert!(result.is_err(), "truncated dict id should be rejected");
        }

        #[test]
        fn zstd_dictionary_id_deterministic() {
            let dict_bytes = b"sample dictionary content for testing";
            let d1 = ZstdDictionary::new(dict_bytes);
            let d2 = ZstdDictionary::new(dict_bytes);
            assert_eq!(d1.id(), d2.id());
        }

        #[test]
        fn zstd_dictionary_different_content_different_id() {
            let d1 = ZstdDictionary::new(b"dictionary one");
            let d2 = ZstdDictionary::new(b"dictionary two");
            assert_ne!(d1.id(), d2.id());
        }

        #[test]
        fn zstd_dictionary_raw_roundtrip() {
            let raw = b"my dictionary bytes";
            let dict = ZstdDictionary::new(raw);
            assert_eq!(dict.raw(), raw);
        }

        #[test]
        fn zstd_dictionary_clone_preserves_id_and_bytes() {
            let dict = ZstdDictionary::new(b"clone me");
            let copy = dict.clone();
            assert_eq!(dict.id(), copy.id());
            assert_eq!(dict.raw(), copy.raw());
        }

        #[test]
        fn zstd_dictionary_debug_format() {
            let dict = ZstdDictionary::new(b"test");
            let debug = format!("{dict:?}");
            assert!(debug.contains("ZstdDictionary"));
            assert!(debug.contains("size: 4"));
            // The id is rendered as `0x` plus eight hex digits.
            assert!(debug.contains(&format!("{:#010x}", dict.id())));
        }
    }
}
452}