Skip to main content

s4_codec/
cpu_zstd.rs

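//! CPU zstd backend — the last-resort fallback for hosts without a GPU, and a test bed.
//!
//! - Straightforward implementation on the `zstd` crate (`zstd-safe` + `zstd-sys`, Apache-2.0 OR MIT).
//! - Compression is CPU-heavy, so the work is moved off the async runtime with `tokio::task::spawn_blocking`.
//! - Slower than nvCOMP in production, but required as the always-on lane for functional / wire-compatibility tests.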
6
7use bytes::Bytes;
8
9use crate::{ChunkManifest, Codec, CodecError, CodecKind};
10
/// CPU-only zstd codec.
///
/// `level` is kept inside zstd's regular range `1..=22`; higher levels trade
/// longer compression time for a better ratio (22 is the maximum).
/// The S4 default is [`CpuZstd::DEFAULT_LEVEL`] (`3`, zstd's own customary
/// default, balancing speed against ratio).
#[derive(Clone, Debug)]
pub struct CpuZstd {
    /// Compression level; [`CpuZstd::new`] saturates it into `1..=22`.
    level: i32,
}

impl CpuZstd {
    /// Level used by [`Default`]: zstd's conventional default of 3.
    pub const DEFAULT_LEVEL: i32 = 3;

    /// Builds a codec at `level`, saturating out-of-range values into `1..=22`
    /// (`0` or anything negative becomes `1`; anything above `22` becomes `22`).
    pub fn new(level: i32) -> Self {
        let level = i32::clamp(level, 1, 22);
        CpuZstd { level }
    }
}

impl Default for CpuZstd {
    /// Same as `CpuZstd::new(CpuZstd::DEFAULT_LEVEL)`.
    fn default() -> Self {
        CpuZstd::new(CpuZstd::DEFAULT_LEVEL)
    }
}
34
35/// Sync, runtime-free decompress used by `s4-codec-wasm` (browser / WASM has
36/// no tokio runtime and no `spawn_blocking`). Same checks as the trait
37/// implementation: codec/size match, decompression-bomb cap at
38/// `manifest.original_size + 1024`, crc32c verify after.
39///
40/// Kept in this module (and not duplicated in the wasm crate) so the bomb
41/// limit + size + crc rules stay defined exactly once.
42pub fn decompress_blocking(input: &[u8], manifest: &ChunkManifest) -> Result<Vec<u8>, CodecError> {
43    if manifest.codec != CodecKind::CpuZstd {
44        return Err(CodecError::CodecMismatch {
45            expected: CodecKind::CpuZstd,
46            got: manifest.codec,
47        });
48    }
49    if input.len() as u64 != manifest.compressed_size {
50        return Err(CodecError::SizeMismatch {
51            expected: manifest.compressed_size,
52            got: input.len() as u64,
53        });
54    }
55    use std::io::Read;
56    let limit = manifest.original_size.saturating_add(1024);
57    let mut decoder = zstd::stream::Decoder::new(input).map_err(CodecError::Io)?;
58    let mut buf = Vec::with_capacity(manifest.original_size as usize);
59    (&mut decoder)
60        .take(limit)
61        .read_to_end(&mut buf)
62        .map_err(CodecError::Io)?;
63    if (buf.len() as u64) > manifest.original_size {
64        return Err(CodecError::Io(std::io::Error::other(format!(
65            "zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
66            buf.len(),
67            manifest.original_size
68        ))));
69    }
70    if buf.len() as u64 != manifest.original_size {
71        return Err(CodecError::SizeMismatch {
72            expected: manifest.original_size,
73            got: buf.len() as u64,
74        });
75    }
76    let actual_crc = crc32c::crc32c(&buf);
77    if actual_crc != manifest.crc32c {
78        return Err(CodecError::CrcMismatch {
79            expected: manifest.crc32c,
80            got: actual_crc,
81        });
82    }
83    Ok(buf)
84}
85
86/// Sync compress sibling of `decompress_blocking`. Provided for symmetry — the
87/// browser side rarely compresses (it's read-only), but having both halves
88/// keeps the API explainable and useful for offline tooling.
89pub fn compress_blocking(input: &[u8], level: i32) -> Result<(Vec<u8>, ChunkManifest), CodecError> {
90    let level = level.clamp(1, 22);
91    let original_size = input.len() as u64;
92    let original_crc = crc32c::crc32c(input);
93    let compressed = zstd::stream::encode_all(input, level).map_err(CodecError::Io)?;
94    Ok((
95        compressed.clone(),
96        ChunkManifest {
97            codec: CodecKind::CpuZstd,
98            original_size,
99            compressed_size: compressed.len() as u64,
100            crc32c: original_crc,
101        },
102    ))
103}
104
105#[async_trait::async_trait]
106impl Codec for CpuZstd {
107    fn kind(&self) -> CodecKind {
108        CodecKind::CpuZstd
109    }
110
111    async fn compress(&self, input: Bytes) -> Result<(Bytes, ChunkManifest), CodecError> {
112        let level = self.level;
113        let original_size = input.len() as u64;
114        let original_crc = crc32c::crc32c(&input);
115
116        let compressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
117            zstd::stream::encode_all(input.as_ref(), level)
118        })
119        .await??;
120
121        let compressed_size = compressed.len() as u64;
122        let manifest = ChunkManifest {
123            codec: CodecKind::CpuZstd,
124            original_size,
125            compressed_size,
126            crc32c: original_crc,
127        };
128        Ok((Bytes::from(compressed), manifest))
129    }
130
131    async fn decompress(
132        &self,
133        input: Bytes,
134        manifest: &ChunkManifest,
135    ) -> Result<Bytes, CodecError> {
136        if manifest.codec != CodecKind::CpuZstd {
137            return Err(CodecError::CodecMismatch {
138                expected: CodecKind::CpuZstd,
139                got: manifest.codec,
140            });
141        }
142        if input.len() as u64 != manifest.compressed_size {
143            return Err(CodecError::SizeMismatch {
144                expected: manifest.compressed_size,
145                got: input.len() as u64,
146            });
147        }
148
149        let expected_crc = manifest.crc32c;
150        let expected_orig_size = manifest.original_size;
151
152        // **Zstd decompression bomb hardening**: 信頼できない入力 (改ざんされた
153        // sidecar / S3 上で bit flip / 攻撃者操作) で `decode_all` が無制限に
154        // 出力を伸ばすと OOM するので、`expected_orig_size + small margin` で
155        // 上限を hard-cap する。Decoder + Read::take パターンで実装。
156        let decompressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
157            use std::io::Read;
158            // 1 KiB margin: zstd の internal buffer flush で多少 overshoot しても
159            // 検出できる余地を残す。expected_orig_size + margin を超えたら
160            // bomb 認定して error にする
161            let limit = expected_orig_size.saturating_add(1024);
162            let mut decoder = zstd::stream::Decoder::new(input.as_ref())?;
163            let mut buf = Vec::with_capacity(expected_orig_size as usize);
164            (&mut decoder).take(limit).read_to_end(&mut buf)?;
165            // limit 以上を消費したかチェック (= bomb)
166            if (buf.len() as u64) > expected_orig_size {
167                return Err(std::io::Error::other(format!(
168                    "zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
169                    buf.len(),
170                    expected_orig_size
171                )));
172            }
173            Ok(buf)
174        })
175        .await??;
176
177        if decompressed.len() as u64 != expected_orig_size {
178            return Err(CodecError::SizeMismatch {
179                expected: expected_orig_size,
180                got: decompressed.len() as u64,
181            });
182        }
183        let actual_crc = crc32c::crc32c(&decompressed);
184        if actual_crc != expected_crc {
185            return Err(CodecError::CrcMismatch {
186                expected: expected_crc,
187                got: actual_crc,
188            });
189        }
190        Ok(Bytes::from(decompressed))
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn roundtrip_small() {
        let codec = CpuZstd::default();
        let input = Bytes::from_static(b"hello squished s3 hello squished s3 hello squished s3");
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        // small string compresses small but not necessarily smaller
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, input.len() as u64);
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    #[tokio::test]
    async fn roundtrip_compressible() {
        let codec = CpuZstd::default();
        // highly compressible payload (1 MB of repeated pattern)
        let input = Bytes::from(vec![b'x'; 1024 * 1024]);
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert!(
            compressed.len() < input.len() / 100,
            "expected zstd to compress 1 MiB of x bytes very well, got {} bytes",
            compressed.len()
        );
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    #[tokio::test]
    async fn detects_corrupted_compressed_payload() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024]);
        let (compressed, manifest) = codec.compress(input).await.unwrap();
        let mut buf = compressed.to_vec();
        // The frame must be long enough for the flip below to land inside it;
        // otherwise the test would decode an untouched payload and fail
        // confusingly on `unwrap_err`.
        assert!(buf.len() > 8, "unexpectedly short zstd frame: {} bytes", buf.len());
        // flip a byte mid-payload
        buf[5] ^= 0xff;
        let err = codec
            .decompress(Bytes::from(buf), &manifest)
            .await
            .unwrap_err();
        // either zstd refuses to decode (Io) or crc check catches it (CrcMismatch)
        assert!(matches!(
            err,
            CodecError::Io(_) | CodecError::CrcMismatch { .. } | CodecError::SizeMismatch { .. }
        ));
    }

    #[tokio::test]
    async fn rejects_codec_mismatch() {
        let codec = CpuZstd::default();
        let manifest = ChunkManifest {
            codec: CodecKind::Passthrough,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let err = codec
            .decompress(Bytes::from_static(b"0123456789"), &manifest)
            .await
            .unwrap_err();
        assert!(matches!(err, CodecError::CodecMismatch { .. }));
    }

    /// A manifest whose `compressed_size` disagrees with the payload length is
    /// rejected before any decoding happens.
    #[tokio::test]
    async fn rejects_compressed_size_mismatch() {
        let codec = CpuZstd::default();
        let input = Bytes::from_static(b"hello squished s3");
        let (compressed, mut manifest) = codec.compress(input).await.unwrap();
        manifest.compressed_size += 1;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        assert!(matches!(err, CodecError::SizeMismatch { .. }));
    }

    /// The async trait path and the sync helper must emit identical frames
    /// and manifests for the same input and level.
    #[tokio::test]
    async fn async_and_blocking_compress_agree() {
        let input: &'static [u8] = b"hello squished s3 hello squished s3 hello squished s3";
        let codec = CpuZstd::default();
        let (async_bytes, async_manifest) =
            codec.compress(Bytes::from_static(input)).await.unwrap();
        let (sync_bytes, sync_manifest) =
            compress_blocking(input, CpuZstd::DEFAULT_LEVEL).unwrap();
        assert_eq!(async_bytes.as_ref(), sync_bytes.as_slice());
        assert_eq!(async_manifest.codec, sync_manifest.codec);
        assert_eq!(async_manifest.original_size, sync_manifest.original_size);
        assert_eq!(async_manifest.compressed_size, sync_manifest.compressed_size);
        assert_eq!(async_manifest.crc32c, sync_manifest.crc32c);
    }

    /// `decompress_blocking` (used by `s4-codec-wasm`) round-trips through
    /// `compress_blocking` with the same checks the async path applies.
    #[test]
    fn blocking_roundtrip() {
        let input = b"hello squished s3 hello squished s3 hello squished s3";
        let (compressed, manifest) = compress_blocking(input, CpuZstd::DEFAULT_LEVEL).unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        let decompressed = decompress_blocking(&compressed, &manifest).unwrap();
        assert_eq!(decompressed, input);
    }

    /// A manifest that under-reports `original_size` must trip the bomb cap
    /// instead of letting the decoder expand the payload to completion.
    #[test]
    fn blocking_rejects_output_larger_than_manifest() {
        let input = vec![b'x'; 64 * 1024];
        let (compressed, mut manifest) =
            compress_blocking(&input, CpuZstd::DEFAULT_LEVEL).unwrap();
        manifest.original_size = 16;
        let err = decompress_blocking(&compressed, &manifest).unwrap_err();
        assert!(matches!(err, CodecError::Io(_)));
    }

    /// A wrong checksum in the manifest is reported as `CrcMismatch`.
    #[test]
    fn blocking_rejects_crc_mismatch() {
        let input = b"hello squished s3 hello squished s3";
        let (compressed, mut manifest) =
            compress_blocking(input, CpuZstd::DEFAULT_LEVEL).unwrap();
        manifest.crc32c ^= 1;
        let err = decompress_blocking(&compressed, &manifest).unwrap_err();
        assert!(matches!(err, CodecError::CrcMismatch { .. }));
    }
}