1use bytes::Bytes;
8
9use crate::{
10 ChunkManifest, Codec, CodecError, CodecKind, DECOMPRESS_BOOTSTRAP_CAPACITY,
11 validate_decompress_manifest,
12};
13
/// CPU-backed zstd codec.
///
/// The only state is the compression level that is handed to the zstd encoder
/// when compressing. The field is private, so code outside this module obtains
/// instances through `CpuZstd::new` or `Default`.
#[derive(Clone, Debug)]
pub struct CpuZstd {
    /// zstd compression level used by this codec.
    level: i32,
}
21
22impl CpuZstd {
23 pub const DEFAULT_LEVEL: i32 = 3;
24
25 pub fn new(level: i32) -> Self {
26 Self {
27 level: level.clamp(1, 22),
28 }
29 }
30}
31
32impl Default for CpuZstd {
33 fn default() -> Self {
34 Self::new(Self::DEFAULT_LEVEL)
35 }
36}
37
38pub fn decompress_blocking(input: &[u8], manifest: &ChunkManifest) -> Result<Vec<u8>, CodecError> {
46 if manifest.codec != CodecKind::CpuZstd {
47 return Err(CodecError::CodecMismatch {
48 expected: CodecKind::CpuZstd,
49 got: manifest.codec,
50 });
51 }
52 let allocated_orig_size = validate_decompress_manifest(manifest, input.len())?;
56 use std::io::Read;
57 let limit = manifest.original_size.saturating_add(1024);
58 let mut decoder = zstd::stream::Decoder::new(input).map_err(CodecError::Io)?;
59 let mut buf = Vec::with_capacity(allocated_orig_size.min(DECOMPRESS_BOOTSTRAP_CAPACITY));
65 {
66 let mut limited = (&mut decoder).take(limit);
67 limited.read_to_end(&mut buf).map_err(CodecError::Io)?;
68 if buf.len() as u64 > manifest.original_size {
77 let mut peek = [0u8; 1];
78 let more_available = limited.read(&mut peek).map(|n| n > 0).unwrap_or(false);
79 return Err(CodecError::Io(std::io::Error::other(format!(
80 "zstd decompression bomb detected: produced at least {} bytes \
81 (truncated at cap = manifest.original_size + 1024 = {}); \
82 manifest claimed {}{}",
83 buf.len(),
84 limit,
85 manifest.original_size,
86 if more_available {
87 "; decoder had more bytes available beyond the cap"
88 } else {
89 ""
90 },
91 ))));
92 }
93 }
94 if buf.len() as u64 != manifest.original_size {
95 return Err(CodecError::SizeMismatch {
96 expected: manifest.original_size,
97 got: buf.len() as u64,
98 });
99 }
100 let actual_crc = crc32c::crc32c(&buf);
101 if actual_crc != manifest.crc32c {
102 return Err(CodecError::CrcMismatch {
103 expected: manifest.crc32c,
104 got: actual_crc,
105 });
106 }
107 Ok(buf)
108}
109
110pub fn compress_blocking(input: &[u8], level: i32) -> Result<(Vec<u8>, ChunkManifest), CodecError> {
114 let level = level.clamp(1, 22);
115 let original_size = input.len() as u64;
116 let original_crc = crc32c::crc32c(input);
117 let compressed = zstd::stream::encode_all(input, level).map_err(CodecError::Io)?;
118 Ok((
119 compressed.clone(),
120 ChunkManifest {
121 codec: CodecKind::CpuZstd,
122 original_size,
123 compressed_size: compressed.len() as u64,
124 crc32c: original_crc,
125 },
126 ))
127}
128
129#[async_trait::async_trait]
130impl Codec for CpuZstd {
131 fn kind(&self) -> CodecKind {
132 CodecKind::CpuZstd
133 }
134
135 async fn compress(&self, input: Bytes) -> Result<(Bytes, ChunkManifest), CodecError> {
136 let level = self.level;
137 let original_size = input.len() as u64;
138 let original_crc = crc32c::crc32c(&input);
139
140 let compressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
141 zstd::stream::encode_all(input.as_ref(), level)
142 })
143 .await??;
144
145 let compressed_size = compressed.len() as u64;
146 let manifest = ChunkManifest {
147 codec: CodecKind::CpuZstd,
148 original_size,
149 compressed_size,
150 crc32c: original_crc,
151 };
152 Ok((Bytes::from(compressed), manifest))
153 }
154
155 async fn decompress(
156 &self,
157 input: Bytes,
158 manifest: &ChunkManifest,
159 ) -> Result<Bytes, CodecError> {
160 if manifest.codec != CodecKind::CpuZstd {
161 return Err(CodecError::CodecMismatch {
162 expected: CodecKind::CpuZstd,
163 got: manifest.codec,
164 });
165 }
166 let allocated_orig_size = validate_decompress_manifest(manifest, input.len())?;
173
174 let expected_crc = manifest.crc32c;
175 let expected_orig_size = manifest.original_size;
176
177 let decompressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
182 use std::io::Read;
183 let limit = expected_orig_size.saturating_add(1024);
187 let mut decoder = zstd::stream::Decoder::new(input.as_ref())?;
188 let mut buf =
191 Vec::with_capacity(allocated_orig_size.min(DECOMPRESS_BOOTSTRAP_CAPACITY));
192 (&mut decoder).take(limit).read_to_end(&mut buf)?;
193 if (buf.len() as u64) > expected_orig_size {
195 return Err(std::io::Error::other(format!(
196 "zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
197 buf.len(),
198 expected_orig_size
199 )));
200 }
201 Ok(buf)
202 })
203 .await??;
204
205 if decompressed.len() as u64 != expected_orig_size {
206 return Err(CodecError::SizeMismatch {
207 expected: expected_orig_size,
208 got: decompressed.len() as u64,
209 });
210 }
211 let actual_crc = crc32c::crc32c(&decompressed);
212 if actual_crc != expected_crc {
213 return Err(CodecError::CrcMismatch {
214 expected: expected_crc,
215 got: actual_crc,
216 });
217 }
218 Ok(Bytes::from(decompressed))
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Arbitrary six-byte payload used by the tests that expect decompression
    /// to be rejected.
    const BOGUS_BODY: &[u8] = &[0x00, 0xd1, 0xd1, 0xd1, 0xd1, 0xd1];

    /// Manifest that claims `original_size` bytes of zstd output for
    /// [`BOGUS_BODY`], with a zero checksum.
    fn bogus_manifest(original_size: u64) -> ChunkManifest {
        ChunkManifest {
            codec: CodecKind::CpuZstd,
            original_size,
            compressed_size: BOGUS_BODY.len() as u64,
            crc32c: 0,
        }
    }

    /// Asserts that `err` is `ManifestSizeExceedsLimit` for a request one byte
    /// above `MAX_DECOMPRESSED_BYTES`.
    fn assert_exceeds_limit(err: CodecError) {
        let CodecError::ManifestSizeExceedsLimit { requested, limit } = err else {
            panic!("expected ManifestSizeExceedsLimit, got {err:?}");
        };
        assert_eq!(requested, crate::MAX_DECOMPRESSED_BYTES + 1);
        assert_eq!(limit, crate::MAX_DECOMPRESSED_BYTES);
    }

    /// Asserts that `err` is either an I/O failure or a size mismatch.
    fn assert_io_or_size_mismatch(err: CodecError) {
        assert!(
            matches!(err, CodecError::Io(_) | CodecError::SizeMismatch { .. }),
            "expected Io or SizeMismatch, got {err:?}"
        );
    }

    /// A short, repetitive payload survives compress -> decompress unchanged.
    #[tokio::test]
    async fn roundtrip_small() {
        let codec = CpuZstd::default();
        let original = Bytes::from_static(b"hello squished s3 hello squished s3 hello squished s3");
        let (packed, manifest) = codec.compress(original.clone()).await.unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, original.len() as u64);
        let unpacked = codec.decompress(packed, &manifest).await.unwrap();
        assert_eq!(unpacked, original);
    }

    /// 1 MiB of a single byte value shrinks to under 1% and round-trips.
    #[tokio::test]
    async fn roundtrip_compressible() {
        let codec = CpuZstd::default();
        let original = Bytes::from(vec![b'x'; 1024 * 1024]);
        let (packed, manifest) = codec.compress(original.clone()).await.unwrap();
        assert!(
            packed.len() < original.len() / 100,
            "expected zstd to compress 1 MiB of x bytes very well, got {} bytes",
            packed.len()
        );
        let unpacked = codec.decompress(packed, &manifest).await.unwrap();
        assert_eq!(unpacked, original);
    }

    /// Flipping the bits of one byte of the compressed payload is rejected.
    #[tokio::test]
    async fn detects_corrupted_compressed_payload() {
        let codec = CpuZstd::default();
        let original = Bytes::from(vec![b'x'; 1024]);
        let (packed, manifest) = codec.compress(original).await.unwrap();

        let mut tampered = packed.to_vec();
        if tampered.len() > 8 {
            tampered[5] ^= 0xff;
        }

        let err = codec
            .decompress(Bytes::from(tampered), &manifest)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            CodecError::Io(_) | CodecError::CrcMismatch { .. } | CodecError::SizeMismatch { .. }
        ));
    }

    /// A manifest that names another codec is refused.
    #[tokio::test]
    async fn rejects_codec_mismatch() {
        let codec = CpuZstd::default();
        let foreign_manifest = ChunkManifest {
            codec: CodecKind::Passthrough,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let outcome = codec
            .decompress(Bytes::from_static(b"0123456789"), &foreign_manifest)
            .await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, CodecError::CodecMismatch { .. }));
    }

    /// A claimed size one byte above `MAX_DECOMPRESSED_BYTES` is rejected with
    /// `ManifestSizeExceedsLimit` (async path).
    #[tokio::test]
    async fn issue_89_rejects_manifest_over_5gib() {
        let codec = CpuZstd::default();
        let manifest = bogus_manifest(crate::MAX_DECOMPRESSED_BYTES + 1);
        let err = codec
            .decompress(Bytes::from_static(BOGUS_BODY), &manifest)
            .await
            .unwrap_err();
        assert_exceeds_limit(err);
    }

    /// A `u32::MAX` size claim over a six-byte body ends in an error rather
    /// than a successful decode (async path).
    #[tokio::test]
    async fn issue_89_bootstrap_cap_keeps_4gib_claim_alloc_safe() {
        let codec = CpuZstd::default();
        let manifest = bogus_manifest(u32::MAX as u64);
        let err = codec
            .decompress(Bytes::from_static(BOGUS_BODY), &manifest)
            .await
            .unwrap_err();
        assert_io_or_size_mismatch(err);
    }

    /// The synchronous helpers round-trip a short payload.
    #[test]
    fn blocking_roundtrip() {
        let original = b"hello squished s3 hello squished s3 hello squished s3";
        let (packed, manifest) = compress_blocking(original, CpuZstd::DEFAULT_LEVEL).unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        let unpacked = decompress_blocking(&packed, &manifest).unwrap();
        assert_eq!(unpacked, original);
    }

    /// Blocking counterpart of `issue_89_rejects_manifest_over_5gib`.
    #[test]
    fn issue_89_blocking_rejects_manifest_over_5gib() {
        let manifest = bogus_manifest(crate::MAX_DECOMPRESSED_BYTES + 1);
        let err = decompress_blocking(BOGUS_BODY, &manifest).unwrap_err();
        assert_exceeds_limit(err);
    }

    /// Blocking counterpart of
    /// `issue_89_bootstrap_cap_keeps_4gib_claim_alloc_safe`.
    #[test]
    fn issue_89_blocking_bootstrap_cap_keeps_4gib_claim_alloc_safe() {
        let manifest = bogus_manifest(u32::MAX as u64);
        let err = decompress_blocking(BOGUS_BODY, &manifest).unwrap_err();
        assert_io_or_size_mismatch(err);
    }
}