Skip to main content

clickhouse_native_client/
compression.rs

//! LZ4 and ZSTD block compression for the ClickHouse native protocol.
//!
//! Each compressed frame has the following wire format:
//!
//! ```text
//! [16 bytes CityHash128 checksum][1 byte method][4 bytes compressed_size][4 bytes uncompressed_size][compressed data]
//! ```
//!
//! Both sizes are little-endian `u32`s. `compressed_size` counts the 9-byte
//! header plus the compressed data, and the checksum is computed over that
//! same header-plus-data span.
//!
//! The `compress` function produces a complete frame (checksum + header +
//! data). The `decompress` function accepts a complete frame and returns the
//! uncompressed payload.
12
13use crate::{
14    protocol::CompressionMethod,
15    Error,
16    Result,
17};
18use bytes::{
19    Buf,
20    BufMut,
21    Bytes,
22    BytesMut,
23};
24use cityhash_rs::cityhash_102_128;
25
/// Length in bytes of the header that follows the checksum: a 1-byte method
/// tag, a 4-byte compressed size and a 4-byte uncompressed size.
const HEADER_SIZE: usize = 1 + 4 + 4;

/// Length in bytes of the CityHash128 checksum that prefixes every frame
/// (two 64-bit halves).
const CHECKSUM_SIZE: usize = 2 * 8;

/// Method tag written into the first header byte (values mirror ClickHouse's
/// `CompressionMethodByte`).
#[repr(u8)]
enum CompressionMethodByte {
    /// Payload is stored verbatim.
    None = 0x02,
    /// Payload is an LZ4 block.
    Lz4 = 0x82,
    /// Payload is ZSTD-compressed.
    Zstd = 0x90,
}

/// Largest value (1 GiB) accepted for either size field of a frame.
const MAX_COMPRESSED_SIZE: usize = 1 << 30;
43
44/// Compress data using the specified method.
45///
46/// Returns a complete compressed frame including CityHash128 checksum,
47/// compression header, and compressed payload.
48///
49/// # Errors
50///
51/// Returns `Error::Compression` if the underlying LZ4 or ZSTD encoder fails.
52pub fn compress(method: CompressionMethod, data: &[u8]) -> Result<Bytes> {
53    match method {
54        CompressionMethod::None => {
55            // No compression, but still add header
56            compress_none(data)
57        }
58        CompressionMethod::Lz4 => compress_lz4(data),
59        CompressionMethod::Zstd => compress_zstd(data),
60    }
61}
62
63/// Decompress data (auto-detects compression method from header).
64///
65/// Expects a complete compressed frame: checksum + header + payload.
66/// The compression method is detected from the header byte.
67///
68/// # Errors
69///
70/// Returns `Error::Compression` if:
71/// - The data is too small for the checksum and header.
72/// - The compressed or uncompressed size exceeds 1 GB.
73/// - The compression method byte is unrecognized.
74/// - The underlying LZ4 or ZSTD decoder fails.
75pub fn decompress(data: &[u8]) -> Result<Bytes> {
76    if data.len() < CHECKSUM_SIZE + HEADER_SIZE {
77        return Err(Error::Compression(
78            "Data too small for checksum and compression header".to_string(),
79        ));
80    }
81
82    // Skip checksum (first 16 bytes) - we could verify it but for now we trust
83    // the TCP layer
84    let data_without_checksum = &data[CHECKSUM_SIZE..];
85
86    let method = data_without_checksum[0];
87    let mut reader = &data_without_checksum[1..];
88
89    // Read compressed size (4 bytes) and uncompressed size (4 bytes)
90    let compressed_size = reader.get_u32_le() as usize;
91    let uncompressed_size = reader.get_u32_le() as usize;
92
93    // Validate sizes
94    if compressed_size > MAX_COMPRESSED_SIZE {
95        return Err(Error::Compression(format!(
96            "Compressed size too large: {}",
97            compressed_size
98        )));
99    }
100
101    if uncompressed_size > MAX_COMPRESSED_SIZE {
102        return Err(Error::Compression(format!(
103            "Uncompressed size too large: {}",
104            uncompressed_size
105        )));
106    }
107
108    // The remaining data after header
109    let compressed_data = &data_without_checksum[HEADER_SIZE..];
110
111    match method {
112        0x02 => {
113            // No compression
114            if compressed_data.len() != uncompressed_size {
115                return Err(Error::Compression(format!(
116                    "Uncompressed data size mismatch: expected {}, got {}",
117                    uncompressed_size,
118                    compressed_data.len()
119                )));
120            }
121            Ok(Bytes::copy_from_slice(compressed_data))
122        }
123        0x82 => {
124            // LZ4
125            decompress_lz4(compressed_data, uncompressed_size)
126        }
127        0x90 => {
128            // ZSTD
129            decompress_zstd(compressed_data, uncompressed_size)
130        }
131        _ => Err(Error::Compression(format!(
132            "Unknown compression method: 0x{:02x}",
133            method
134        ))),
135    }
136}
137
138/// Compress using LZ4
139fn compress_lz4(data: &[u8]) -> Result<Bytes> {
140    let max_compressed_size = lz4::block::compress_bound(data.len())?;
141    let mut compressed = vec![0u8; max_compressed_size];
142
143    let compressed_size =
144        lz4::block::compress_to_buffer(data, None, false, &mut compressed)?;
145
146    compressed.truncate(compressed_size);
147
148    // Build header + compressed data
149    let mut header_and_data =
150        BytesMut::with_capacity(HEADER_SIZE + compressed_size);
151
152    // Write header
153    header_and_data.put_u8(CompressionMethodByte::Lz4 as u8);
154    header_and_data.put_u32_le((HEADER_SIZE + compressed_size) as u32); // Total size including header
155    header_and_data.put_u32_le(data.len() as u32); // Uncompressed size
156
157    // Write compressed data
158    header_and_data.put_slice(&compressed);
159
160    // Compute CityHash128 checksum of header + compressed data
161    let checksum = cityhash_102_128(&header_and_data);
162
163    // Build final output with checksum
164    // CityHash128 returns u128, write as (high64, low64) - reverse of typical
165    // order
166    let mut output =
167        BytesMut::with_capacity(CHECKSUM_SIZE + header_and_data.len());
168    output.put_u64_le((checksum >> 64) as u64); // High 64 bits first
169    output.put_u64_le(checksum as u64); // Low 64 bits second
170    output.put_slice(&header_and_data);
171
172    Ok(output.freeze())
173}
174
175/// Decompress LZ4 data
176fn decompress_lz4(data: &[u8], uncompressed_size: usize) -> Result<Bytes> {
177    let decompressed =
178        lz4::block::decompress(data, Some(uncompressed_size as i32))?;
179
180    if decompressed.len() != uncompressed_size {
181        return Err(Error::Compression(format!(
182            "LZ4 decompression size mismatch: expected {}, got {}",
183            uncompressed_size,
184            decompressed.len()
185        )));
186    }
187
188    Ok(Bytes::from(decompressed))
189}
190
191/// Compress using ZSTD
192fn compress_zstd(data: &[u8]) -> Result<Bytes> {
193    let compressed = zstd::bulk::compress(data, 3) // Compression level 3
194        .map_err(|e| {
195            Error::Compression(format!("ZSTD compression failed: {}", e))
196        })?;
197
198    // Build header + compressed data
199    let mut header_and_data =
200        BytesMut::with_capacity(HEADER_SIZE + compressed.len());
201
202    // Write header
203    header_and_data.put_u8(CompressionMethodByte::Zstd as u8);
204    header_and_data.put_u32_le((HEADER_SIZE + compressed.len()) as u32); // Total size including header
205    header_and_data.put_u32_le(data.len() as u32); // Uncompressed size
206
207    // Write compressed data
208    header_and_data.put_slice(&compressed);
209
210    // Compute CityHash128 checksum of header + compressed data
211    let checksum = cityhash_102_128(&header_and_data);
212
213    // Build final output with checksum
214    // CityHash128 returns u128, write as (high64, low64) - reverse of typical
215    // order
216    let mut output =
217        BytesMut::with_capacity(CHECKSUM_SIZE + header_and_data.len());
218    output.put_u64_le((checksum >> 64) as u64); // High 64 bits first
219    output.put_u64_le(checksum as u64); // Low 64 bits second
220    output.put_slice(&header_and_data);
221
222    Ok(output.freeze())
223}
224
225/// Decompress ZSTD data
226fn decompress_zstd(data: &[u8], uncompressed_size: usize) -> Result<Bytes> {
227    let decompressed = zstd::bulk::decompress(data, uncompressed_size)
228        .map_err(|e| {
229            Error::Compression(format!("ZSTD decompression failed: {}", e))
230        })?;
231
232    if decompressed.len() != uncompressed_size {
233        return Err(Error::Compression(format!(
234            "ZSTD decompression size mismatch: expected {}, got {}",
235            uncompressed_size,
236            decompressed.len()
237        )));
238    }
239
240    Ok(Bytes::from(decompressed))
241}
242
243/// No compression (just adds header)
244fn compress_none(data: &[u8]) -> Result<Bytes> {
245    // Build header + data
246    let mut header_and_data =
247        BytesMut::with_capacity(HEADER_SIZE + data.len());
248
249    // Write header
250    header_and_data.put_u8(CompressionMethodByte::None as u8);
251    header_and_data.put_u32_le((HEADER_SIZE + data.len()) as u32); // Total size
252    header_and_data.put_u32_le(data.len() as u32); // Uncompressed size (same as total)
253
254    // Write uncompressed data
255    header_and_data.put_slice(data);
256
257    // Compute CityHash128 checksum of header + data
258    let checksum = cityhash_102_128(&header_and_data);
259
260    // Build final output with checksum
261    let mut output =
262        BytesMut::with_capacity(CHECKSUM_SIZE + header_and_data.len());
263    output.put_u128_le(checksum); // CityHash128 as little-endian u128
264    output.put_slice(&header_and_data);
265
266    Ok(output.freeze())
267}
268
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    #[test]
    fn test_compress_decompress_none() {
        let original = b"Hello, ClickHouse!";

        let compressed = compress(CompressionMethod::None, original).unwrap();
        let decompressed = decompress(&compressed).unwrap();

        assert_eq!(&decompressed[..], original);
    }

    #[test]
    fn test_compress_decompress_lz4() {
        let original = b"Hello, ClickHouse! ".repeat(100);

        let compressed = compress(CompressionMethod::Lz4, &original).unwrap();
        let decompressed = decompress(&compressed).unwrap();

        assert_eq!(&decompressed[..], &original[..]);

        // Should achieve some compression
        assert!(compressed.len() < original.len());
    }

    #[test]
    fn test_compress_decompress_zstd() {
        let original =
            b"ClickHouse is a fast open-source column-oriented database"
                .repeat(50);

        let compressed = compress(CompressionMethod::Zstd, &original).unwrap();
        let decompressed = decompress(&compressed).unwrap();

        assert_eq!(&decompressed[..], &original[..]);

        // Should achieve good compression
        assert!(compressed.len() < original.len());
    }

    #[test]
    fn test_empty_data() {
        let original = b"";

        // Should work with empty data
        let compressed = compress(CompressionMethod::Lz4, original).unwrap();
        let decompressed = decompress(&compressed).unwrap();

        assert_eq!(&decompressed[..], original);
    }

    #[test]
    fn test_large_data_lz4() {
        // Test with larger data
        let original = vec![42u8; 100_000];

        let compressed = compress(CompressionMethod::Lz4, &original).unwrap();
        let decompressed = decompress(&compressed).unwrap();

        assert_eq!(&decompressed[..], &original[..]);

        // Should compress very well (all same byte)
        assert!(compressed.len() < original.len() / 10);
    }

    #[test]
    fn test_invalid_compression_method() {
        // An otherwise well-formed empty frame: checksum bytes, then a header
        // with an unknown method byte, compressed_size = 9 (header only) and
        // uncompressed_size = 0. It must be long enough to pass the minimum
        // size check so that method dispatch is actually reached.
        let mut bad_data = vec![0u8; CHECKSUM_SIZE + HEADER_SIZE];
        bad_data[CHECKSUM_SIZE] = 0xFF; // Invalid method byte
        bad_data[CHECKSUM_SIZE + 1..CHECKSUM_SIZE + 5]
            .copy_from_slice(&(HEADER_SIZE as u32).to_le_bytes());
        bad_data[CHECKSUM_SIZE + 5..CHECKSUM_SIZE + 9]
            .copy_from_slice(&0u32.to_le_bytes());

        let result = decompress(&bad_data);
        assert!(matches!(
            &result,
            Err(Error::Compression(msg))
                if msg.contains("Unknown compression method")
        ));
    }

    #[test]
    fn test_compressed_size_too_large() {
        let mut frame = vec![0u8; CHECKSUM_SIZE + HEADER_SIZE];
        frame[CHECKSUM_SIZE] = CompressionMethodByte::None as u8;
        frame[CHECKSUM_SIZE + 1..CHECKSUM_SIZE + 5]
            .copy_from_slice(&((MAX_COMPRESSED_SIZE + 1) as u32).to_le_bytes());

        let result = decompress(&frame);
        assert!(matches!(
            &result,
            Err(Error::Compression(msg))
                if msg.contains("Compressed size too large")
        ));
    }

    #[test]
    fn test_header_too_small() {
        // Only 4 bytes; a frame needs at least 16 (checksum) + 9 (header).
        let bad_data = vec![0x82, 1, 2, 3];

        let result = decompress(&bad_data);
        assert!(result.is_err());
    }
}