// s4_codec/cpu_zstd.rs

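//! CPU zstd backend — the ultimate fallback for environments without a GPU, and a test bed.
//!
//! - Straightforward implementation on top of the `zstd` crate (`zstd-safe` + `zstd-sys`, Apache-2.0 OR MIT)
//! - Compression is CPU-heavy, so it is offloaded to another thread via `tokio::task::spawn_blocking`
//! - Slower than nvCOMP in production, but required as a permanent lane for functional / wire-compatibility tests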
6
7use bytes::Bytes;
8
9use crate::{ChunkManifest, Codec, CodecError, CodecKind};
10
/// CPU zstd codec.
///
/// `level` is the zstd compression level, always kept within
/// [`CpuZstd::MIN_LEVEL`]`..=`[`CpuZstd::MAX_LEVEL`] (1..=22). Level 22 gives the best
/// compression ratio but takes a long time.
///
/// The S4 default is `3` (zstd's usual default; a balance of speed and compression ratio).
#[derive(Debug, Clone)]
pub struct CpuZstd {
    level: i32,
}

impl CpuZstd {
    /// Level used by [`Default`].
    pub const DEFAULT_LEVEL: i32 = 3;
    /// Lowest level this codec will use; smaller requests are raised to it.
    pub const MIN_LEVEL: i32 = 1;
    /// Highest level this codec will use; larger requests are lowered to it.
    pub const MAX_LEVEL: i32 = 22;

    /// Creates a codec for the requested compression `level`.
    ///
    /// Out-of-range values are not rejected: they are clamped into
    /// [`Self::MIN_LEVEL`]`..=`[`Self::MAX_LEVEL`], so this constructor never fails.
    #[must_use]
    pub fn new(level: i32) -> Self {
        Self {
            level: level.clamp(Self::MIN_LEVEL, Self::MAX_LEVEL),
        }
    }

    /// Returns the effective (already clamped) compression level.
    #[must_use]
    pub fn level(&self) -> i32 {
        self.level
    }
}

impl Default for CpuZstd {
    /// Builds a codec at [`CpuZstd::DEFAULT_LEVEL`].
    fn default() -> Self {
        Self::new(Self::DEFAULT_LEVEL)
    }
}
34
35#[async_trait::async_trait]
36impl Codec for CpuZstd {
37    fn kind(&self) -> CodecKind {
38        CodecKind::CpuZstd
39    }
40
41    async fn compress(&self, input: Bytes) -> Result<(Bytes, ChunkManifest), CodecError> {
42        let level = self.level;
43        let original_size = input.len() as u64;
44        let original_crc = crc32c::crc32c(&input);
45
46        let compressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
47            zstd::stream::encode_all(input.as_ref(), level)
48        })
49        .await??;
50
51        let compressed_size = compressed.len() as u64;
52        let manifest = ChunkManifest {
53            codec: CodecKind::CpuZstd,
54            original_size,
55            compressed_size,
56            crc32c: original_crc,
57        };
58        Ok((Bytes::from(compressed), manifest))
59    }
60
61    async fn decompress(
62        &self,
63        input: Bytes,
64        manifest: &ChunkManifest,
65    ) -> Result<Bytes, CodecError> {
66        if manifest.codec != CodecKind::CpuZstd {
67            return Err(CodecError::CodecMismatch {
68                expected: CodecKind::CpuZstd,
69                got: manifest.codec,
70            });
71        }
72        if input.len() as u64 != manifest.compressed_size {
73            return Err(CodecError::SizeMismatch {
74                expected: manifest.compressed_size,
75                got: input.len() as u64,
76            });
77        }
78
79        let expected_crc = manifest.crc32c;
80        let expected_orig_size = manifest.original_size;
81
82        // **Zstd decompression bomb hardening**: 信頼できない入力 (改ざんされた
83        // sidecar / S3 上で bit flip / 攻撃者操作) で `decode_all` が無制限に
84        // 出力を伸ばすと OOM するので、`expected_orig_size + small margin` で
85        // 上限を hard-cap する。Decoder + Read::take パターンで実装。
86        let decompressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
87            use std::io::Read;
88            // 1 KiB margin: zstd の internal buffer flush で多少 overshoot しても
89            // 検出できる余地を残す。expected_orig_size + margin を超えたら
90            // bomb 認定して error にする
91            let limit = expected_orig_size.saturating_add(1024);
92            let mut decoder = zstd::stream::Decoder::new(input.as_ref())?;
93            let mut buf = Vec::with_capacity(expected_orig_size as usize);
94            (&mut decoder).take(limit).read_to_end(&mut buf)?;
95            // limit 以上を消費したかチェック (= bomb)
96            if (buf.len() as u64) > expected_orig_size {
97                return Err(std::io::Error::other(format!(
98                    "zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
99                    buf.len(),
100                    expected_orig_size
101                )));
102            }
103            Ok(buf)
104        })
105        .await??;
106
107        if decompressed.len() as u64 != expected_orig_size {
108            return Err(CodecError::SizeMismatch {
109                expected: expected_orig_size,
110                got: decompressed.len() as u64,
111            });
112        }
113        let actual_crc = crc32c::crc32c(&decompressed);
114        if actual_crc != expected_crc {
115            return Err(CodecError::CrcMismatch {
116                expected: expected_crc,
117                got: actual_crc,
118            });
119        }
120        Ok(Bytes::from(decompressed))
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Out-of-range levels are clamped into 1..=22; in-range levels are kept.
    #[test]
    fn new_clamps_level() {
        assert_eq!(CpuZstd::new(-5).level, 1);
        assert_eq!(CpuZstd::new(0).level, 1);
        assert_eq!(CpuZstd::new(7).level, 7);
        assert_eq!(CpuZstd::new(100).level, 22);
        assert_eq!(CpuZstd::default().level, CpuZstd::DEFAULT_LEVEL);
    }

    /// A short payload survives a compress -> decompress round trip.
    #[tokio::test]
    async fn roundtrip_small() {
        let codec = CpuZstd::default();
        let input = Bytes::from_static(b"hello squished s3 hello squished s3 hello squished s3");
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        // A small string is not guaranteed to shrink, so only the manifest is checked.
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, input.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    /// An empty payload round-trips and is recorded with `original_size == 0`.
    #[tokio::test]
    async fn roundtrip_empty() {
        let codec = CpuZstd::default();
        let input = Bytes::new();
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert_eq!(manifest.original_size, 0);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    /// A highly compressible payload shrinks dramatically and still round-trips.
    #[tokio::test]
    async fn roundtrip_compressible() {
        let codec = CpuZstd::default();
        // Highly compressible payload (1 MiB of a repeated byte).
        let input = Bytes::from(vec![b'x'; 1024 * 1024]);
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert!(
            compressed.len() < input.len() / 100,
            "expected zstd to compress 1 MiB of x bytes very well, got {} bytes",
            compressed.len()
        );
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    /// Flipping a byte of the compressed stream must be reported as an error.
    #[tokio::test]
    async fn detects_corrupted_compressed_payload() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024]);
        let (compressed, manifest) = codec.compress(input).await.unwrap();

        let mut buf = compressed.to_vec();
        // Fail loudly instead of silently skipping the corruption on a short stream.
        assert!(
            buf.len() > 8,
            "compressed payload unexpectedly short: {} bytes",
            buf.len()
        );
        // Invert one byte near the start of the stream (length stays the same, so the
        // compressed-size check still passes and the decoder sees the damage).
        buf[5] ^= 0xff;
        let corrupted = Bytes::from(buf);

        let err = codec.decompress(corrupted, &manifest).await.unwrap_err();
        // Either zstd refuses to decode (Io), the output length is off (SizeMismatch),
        // or the checksum catches it (CrcMismatch).
        assert!(matches!(
            err,
            CodecError::Io(_) | CodecError::CrcMismatch { .. } | CodecError::SizeMismatch { .. }
        ));
    }

    /// A payload whose length disagrees with `compressed_size` is rejected up front.
    #[tokio::test]
    async fn rejects_compressed_size_mismatch() {
        let codec = CpuZstd::default();
        let (compressed, manifest) = codec
            .compress(Bytes::from(vec![b'x'; 1024]))
            .await
            .unwrap();
        let truncated = compressed.slice(..compressed.len() - 1);
        let err = codec.decompress(truncated, &manifest).await.unwrap_err();
        assert!(matches!(err, CodecError::SizeMismatch { .. }));
    }

    /// Output that exceeds the manifest's claimed size trips the bomb guard.
    #[tokio::test]
    async fn rejects_output_larger_than_manifest() {
        let codec = CpuZstd::default();
        let (compressed, mut manifest) = codec
            .compress(Bytes::from(vec![b'x'; 1024 * 1024]))
            .await
            .unwrap();
        // Claim far less than the stream really expands to.
        manifest.original_size = 1024;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        assert!(matches!(err, CodecError::Io(_)));
    }

    /// A manifest written by a different codec is refused before any decoding.
    #[tokio::test]
    async fn rejects_codec_mismatch() {
        let codec = CpuZstd::default();
        let manifest = ChunkManifest {
            codec: CodecKind::Passthrough,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let err = codec
            .decompress(Bytes::from_static(b"0123456789"), &manifest)
            .await
            .unwrap_err();
        assert!(matches!(err, CodecError::CodecMismatch { .. }));
    }
}
189}