hermes_core/compression/
zstd.rs1use std::io::{self, Write};
9
/// A zstd compression level wrapped in a newtype.
///
/// The inner `i32` is public and is forwarded unchanged to the zstd encoder by
/// `compress` and `compress_with_dict`; no range validation happens here.
///
/// Comparison, ordering and hashing are derived from the wrapped integer, so
/// levels can be compared directly (`level == CompressionLevel::MAX`,
/// `a < b`) without reaching into `.0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CompressionLevel(pub i32);

impl CompressionLevel {
    /// Preset for level 1.
    pub const FAST: Self = Self(1);
    /// Preset for level 3.
    pub const DEFAULT: Self = Self(3);
    /// Preset for level 9.
    pub const BETTER: Self = Self(9);
    /// Preset for level 19.
    pub const BEST: Self = Self(19);
    /// Preset for level 22, the highest preset defined here.
    pub const MAX: Self = Self(22);
}

impl Default for CompressionLevel {
    /// Returns [`CompressionLevel::MAX`] (level 22).
    ///
    /// NOTE(review): this is *not* [`CompressionLevel::DEFAULT`] (level 3).
    /// Presumably the crate deliberately favours compression ratio over speed
    /// when no level is given — confirm, because the naming is easy to misread.
    fn default() -> Self {
        Self::MAX
    }
}
32
33#[derive(Clone)]
35pub struct CompressionDict {
36 raw_dict: Vec<u8>,
37}
38
39impl CompressionDict {
40 pub fn train(samples: &[&[u8]], dict_size: usize) -> io::Result<Self> {
45 let raw_dict = zstd::dict::from_samples(samples, dict_size).map_err(io::Error::other)?;
46 Ok(Self { raw_dict })
47 }
48
49 pub fn from_bytes(bytes: Vec<u8>) -> Self {
51 Self { raw_dict: bytes }
52 }
53
54 pub fn as_bytes(&self) -> &[u8] {
56 &self.raw_dict
57 }
58
59 pub fn len(&self) -> usize {
61 self.raw_dict.len()
62 }
63
64 pub fn is_empty(&self) -> bool {
66 self.raw_dict.is_empty()
67 }
68}
69
70pub fn compress(data: &[u8], level: CompressionLevel) -> io::Result<Vec<u8>> {
72 zstd::encode_all(data, level.0).map_err(io::Error::other)
73}
74
75pub fn compress_with_dict(
77 data: &[u8],
78 level: CompressionLevel,
79 dict: &CompressionDict,
80) -> io::Result<Vec<u8>> {
81 let mut encoder = zstd::Encoder::with_dictionary(Vec::new(), level.0, &dict.raw_dict)
82 .map_err(io::Error::other)?;
83 encoder.write_all(data)?;
84 encoder.finish().map_err(io::Error::other)
85}
86
87pub fn decompress(data: &[u8]) -> io::Result<Vec<u8>> {
93 zstd::bulk::decompress(data, 512 * 1024).map_err(io::Error::other)
98}
99
100pub fn decompress_with_dict(data: &[u8], dict: &CompressionDict) -> io::Result<Vec<u8>> {
105 let mut decompressor =
106 zstd::bulk::Decompressor::with_dictionary(&dict.raw_dict).map_err(io::Error::other)?;
107 decompressor
108 .decompress(data, 512 * 1024)
109 .map_err(io::Error::other)
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Compressing then decompressing repetitive text returns the input, and
    /// the compressed form is smaller than the input.
    #[test]
    fn test_roundtrip() {
        let original = b"Hello, World! This is a test of compression.".repeat(100);

        let packed = compress(&original, CompressionLevel::default()).unwrap();
        let restored = decompress(&packed).unwrap();

        assert_eq!(original, restored.as_slice());
        assert!(packed.len() < original.len());
    }

    /// An empty input survives the round trip as an empty output.
    #[test]
    fn test_empty_data() {
        let nothing: &[u8] = &[];

        let packed = compress(nothing, CompressionLevel::default()).unwrap();
        let restored = decompress(&packed).unwrap();

        assert!(restored.is_empty());
    }

    /// The round trip is lossless at each of several compression levels.
    #[test]
    fn test_compression_levels() {
        let original = b"Test data for compression levels".repeat(100);

        [1, 3, 9, 19]
            .into_iter()
            .map(CompressionLevel)
            .for_each(|level| {
                let packed = compress(&original, level).unwrap();
                let restored = decompress(&packed).unwrap();
                assert_eq!(original.as_slice(), restored.as_slice());
            });
    }
}