Skip to main content

connectrpc_axum_core/
codec.rs

//! Compression codec trait and implementations.
//!
//! This module provides the [`Codec`] trait for per-message compression, the
//! type-erased [`BoxedCodec`] wrapper, and implementations for common algorithms:
//! - [`IdentityCodec`]: no compression (always available)
//! - [`GzipCodec`]: Gzip compression (requires the `compression-gzip-stream` feature)
//! - [`DeflateCodec`]: zlib-format deflate (requires the `compression-deflate-stream` feature)
//! - [`BrotliCodec`]: Brotli compression (requires the `compression-br-stream` feature)
//! - [`ZstdCodec`]: Zstd compression (requires the `compression-zstd-stream` feature)
9
10use bytes::Bytes;
11use std::io;
12use std::sync::Arc;
13
14#[cfg(any(
15    feature = "compression-gzip-stream",
16    feature = "compression-deflate-stream",
17    feature = "compression-br-stream",
18    feature = "compression-zstd-stream"
19))]
20use std::io::{Read, Write};
21
22#[cfg(feature = "compression-gzip-stream")]
23use flate2::Compression as GzipLevel;
24#[cfg(feature = "compression-gzip-stream")]
25use flate2::read::GzDecoder;
26#[cfg(feature = "compression-gzip-stream")]
27use flate2::write::GzEncoder;
28
29/// Codec trait for per-message (envelope) compression.
30///
31/// Used for streaming Connect RPCs where each message is individually compressed.
32/// HTTP body compression for unary RPCs is typically handled by middleware.
33///
34/// # Example
35///
36/// ```ignore
37/// use connectrpc_axum_core::Codec;
38/// use bytes::Bytes;
39/// use std::io;
40///
41/// struct Lz4Codec;
42///
43/// impl Codec for Lz4Codec {
44///     fn name(&self) -> &'static str { "lz4" }
45///
46///     fn compress(&self, data: &[u8]) -> io::Result<Bytes> {
47///         // ... lz4 compression
48///     }
49///
50///     fn decompress(&self, data: &[u8]) -> io::Result<Bytes> {
51///         // ... lz4 decompression
52///     }
53/// }
54/// ```
55pub trait Codec: Send + Sync + 'static {
56    /// The encoding name for HTTP headers (e.g., "gzip", "zstd", "br").
57    fn name(&self) -> &'static str;
58
59    /// Compress data.
60    fn compress(&self, data: &[u8]) -> io::Result<Bytes>;
61
62    /// Decompress data.
63    fn decompress(&self, data: &[u8]) -> io::Result<Bytes>;
64}
65
66/// A boxed codec for type-erased storage.
67///
68/// Use `Option<BoxedCodec>` where `None` represents identity (no compression).
69#[derive(Clone)]
70pub struct BoxedCodec(Arc<dyn Codec>);
71
72impl BoxedCodec {
73    /// Create a new boxed codec.
74    pub fn new<C: Codec>(codec: C) -> Self {
75        BoxedCodec(Arc::new(codec))
76    }
77
78    /// Get the codec name for HTTP headers.
79    pub fn name(&self) -> &'static str {
80        self.0.name()
81    }
82
83    /// Compress data.
84    pub fn compress(&self, data: &[u8]) -> io::Result<Bytes> {
85        self.0.compress(data)
86    }
87
88    /// Decompress data.
89    pub fn decompress(&self, data: &[u8]) -> io::Result<Bytes> {
90        self.0.decompress(data)
91    }
92}
93
94impl std::fmt::Debug for BoxedCodec {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        f.debug_tuple("BoxedCodec").field(&self.name()).finish()
97    }
98}
99
/// Gzip codec backed by `flate2`.
///
/// Requires the `compression-gzip-stream` feature.
#[cfg(feature = "compression-gzip-stream")]
#[derive(Debug, Clone, Copy)]
pub struct GzipCodec {
    /// Compression level on a 0-9 scale. Default is 6.
    pub level: u32,
}

#[cfg(feature = "compression-gzip-stream")]
impl Default for GzipCodec {
    /// Uses compression level 6.
    fn default() -> Self {
        Self { level: 6 }
    }
}

#[cfg(feature = "compression-gzip-stream")]
impl GzipCodec {
    /// Creates a `GzipCodec` with the given compression level.
    ///
    /// Levels range from 0 (no compression) to 9 (best compression); values
    /// above 9 are clamped to 9.
    pub fn with_level(level: u32) -> Self {
        Self {
            level: level.min(9),
        }
    }
}
128
#[cfg(feature = "compression-gzip-stream")]
impl Codec for GzipCodec {
    fn name(&self) -> &'static str {
        "gzip"
    }

    /// Compresses `data` into a single gzip member.
    fn compress(&self, data: &[u8]) -> io::Result<Bytes> {
        // `level` is a public field, so it can hold values that never went
        // through `with_level`. Clamp here too: flate2 documents levels on a
        // 0-9 scale, and out-of-range values may be rejected by some backends.
        let level = GzipLevel::new(self.level.min(9));
        let mut encoder = GzEncoder::new(Vec::new(), level);
        encoder.write_all(data)?;
        Ok(Bytes::from(encoder.finish()?))
    }

    /// Decompresses a gzip stream.
    ///
    /// Uses flate2's `GzDecoder`, which decodes only the first gzip member.
    fn decompress(&self, data: &[u8]) -> io::Result<Bytes> {
        // NOTE(review): the decompressed size is unbounded here; callers that
        // accept untrusted input should enforce a message-size limit.
        let mut decoder = GzDecoder::new(data);
        let mut decompressed = Vec::new();
        decoder.read_to_end(&mut decompressed)?;
        Ok(Bytes::from(decompressed))
    }
}
148
149/// Identity codec (no compression).
150///
151/// This codec passes data through unchanged.
152#[derive(Debug, Clone, Copy, Default)]
153pub struct IdentityCodec;
154
155impl Codec for IdentityCodec {
156    fn name(&self) -> &'static str {
157        "identity"
158    }
159
160    fn compress(&self, data: &[u8]) -> io::Result<Bytes> {
161        Ok(Bytes::copy_from_slice(data))
162    }
163
164    fn decompress(&self, data: &[u8]) -> io::Result<Bytes> {
165        Ok(Bytes::copy_from_slice(data))
166    }
167}
168
/// Deflate codec backed by `flate2`, using the zlib container format.
///
/// Note: the HTTP "deflate" Content-Encoding means zlib format (RFC 1950),
/// not raw DEFLATE (RFC 1951).
///
/// Requires the `compression-deflate-stream` feature.
#[cfg(feature = "compression-deflate-stream")]
#[derive(Debug, Clone, Copy)]
pub struct DeflateCodec {
    /// Compression level on a 0-9 scale. Default is 6.
    pub level: u32,
}

#[cfg(feature = "compression-deflate-stream")]
impl Default for DeflateCodec {
    /// Uses compression level 6.
    fn default() -> Self {
        Self { level: 6 }
    }
}

#[cfg(feature = "compression-deflate-stream")]
impl DeflateCodec {
    /// Creates a `DeflateCodec` with the given compression level.
    ///
    /// Levels range from 0 (no compression) to 9 (best compression); values
    /// above 9 are clamped to 9.
    pub fn with_level(level: u32) -> Self {
        Self {
            level: level.min(9),
        }
    }
}
200
#[cfg(feature = "compression-deflate-stream")]
impl Codec for DeflateCodec {
    fn name(&self) -> &'static str {
        "deflate"
    }

    /// Compresses `data` into a zlib stream.
    fn compress(&self, data: &[u8]) -> io::Result<Bytes> {
        use flate2::write::ZlibEncoder;
        // `level` is a public field and may bypass `with_level`; clamp to the
        // 0-9 scale flate2 documents so an out-of-range value never reaches the
        // backend.
        let level = flate2::Compression::new(self.level.min(9));
        let mut encoder = ZlibEncoder::new(Vec::new(), level);
        encoder.write_all(data)?;
        Ok(Bytes::from(encoder.finish()?))
    }

    /// Decompresses a zlib stream.
    fn decompress(&self, data: &[u8]) -> io::Result<Bytes> {
        use flate2::read::ZlibDecoder;
        // NOTE(review): the decompressed size is unbounded here; callers that
        // accept untrusted input should enforce a message-size limit.
        let mut decoder = ZlibDecoder::new(data);
        let mut decompressed = Vec::new();
        decoder.read_to_end(&mut decompressed)?;
        Ok(Bytes::from(decompressed))
    }
}
222
/// Brotli codec backed by the `brotli` crate.
///
/// Requires the `compression-br-stream` feature.
#[cfg(feature = "compression-br-stream")]
#[derive(Debug, Clone, Copy)]
pub struct BrotliCodec {
    /// Compression quality on a 0-11 scale. Default is 4.
    pub quality: u32,
}

#[cfg(feature = "compression-br-stream")]
impl Default for BrotliCodec {
    /// Uses quality 4.
    fn default() -> Self {
        Self { quality: 4 }
    }
}

#[cfg(feature = "compression-br-stream")]
impl BrotliCodec {
    /// Creates a `BrotliCodec` with the given quality level.
    ///
    /// Quality ranges from 0 (fastest) to 11 (best compression); values above
    /// 11 are clamped to 11.
    pub fn with_quality(quality: u32) -> Self {
        Self {
            quality: quality.min(11),
        }
    }
}
251
#[cfg(feature = "compression-br-stream")]
impl Codec for BrotliCodec {
    fn name(&self) -> &'static str {
        "br"
    }

    /// Compresses `data` with the configured quality.
    fn compress(&self, data: &[u8]) -> io::Result<Bytes> {
        use brotli::enc::BrotliEncoderParams;
        // `quality` is a public `u32` and may bypass `with_quality`. Clamp to
        // 0-11 before converting: a raw `as i32` on values above `i32::MAX`
        // would wrap to a negative quality. After the clamp the cast is lossless.
        let quality = self.quality.min(11) as i32;
        let params = BrotliEncoderParams {
            quality,
            ..Default::default()
        };
        let mut output = Vec::new();
        brotli::enc::BrotliCompress(&mut std::io::Cursor::new(data), &mut output, &params)?;
        Ok(Bytes::from(output))
    }

    /// Decompresses a Brotli stream.
    fn decompress(&self, data: &[u8]) -> io::Result<Bytes> {
        // NOTE(review): the decompressed size is unbounded here; callers that
        // accept untrusted input should enforce a message-size limit.
        let mut output = Vec::new();
        brotli::BrotliDecompress(&mut std::io::Cursor::new(data), &mut output)?;
        Ok(Bytes::from(output))
    }
}
275
/// Zstd codec backed by the `zstd` crate.
///
/// Requires the `compression-zstd-stream` feature.
#[cfg(feature = "compression-zstd-stream")]
#[derive(Debug, Clone, Copy)]
pub struct ZstdCodec {
    /// Compression level (1-22). Default is 3.
    pub level: i32,
}

#[cfg(feature = "compression-zstd-stream")]
impl Default for ZstdCodec {
    /// Uses compression level 3.
    fn default() -> Self {
        Self { level: 3 }
    }
}

#[cfg(feature = "compression-zstd-stream")]
impl ZstdCodec {
    /// Creates a `ZstdCodec` with the given compression level.
    ///
    /// Levels range from 1 (fastest) to 22 (best compression); values outside
    /// that range are clamped into it.
    pub fn with_level(level: i32) -> Self {
        Self {
            level: level.clamp(1, 22),
        }
    }
}
304
#[cfg(feature = "compression-zstd-stream")]
impl Codec for ZstdCodec {
    fn name(&self) -> &'static str {
        "zstd"
    }

    /// Compresses `data` into a single zstd frame at the configured level.
    fn compress(&self, data: &[u8]) -> io::Result<Bytes> {
        // `zstd::bulk::compress` already returns `io::Result`; propagate its
        // error as-is instead of nesting it inside a second `io::Error`.
        let compressed = zstd::bulk::compress(data, self.level)?;
        Ok(Bytes::from(compressed))
    }

    /// Decompresses zstd-encoded `data`.
    fn decompress(&self, data: &[u8]) -> io::Result<Bytes> {
        // NOTE(review): the decompressed size is unbounded here; callers that
        // accept untrusted input should enforce a message-size limit.
        let mut decoder = zstd::Decoder::new(data)?;
        let mut decompressed = Vec::new();
        decoder.read_to_end(&mut decompressed)?;
        Ok(Bytes::from(decompressed))
    }
}
324
325/// Compress bytes using the specified codec.
326///
327/// If `codec` is `None`, returns the input unchanged (identity).
328pub fn compress_bytes(bytes: Bytes, codec: Option<&BoxedCodec>) -> io::Result<Bytes> {
329    match codec {
330        None => Ok(bytes), // identity: zero-copy passthrough
331        Some(c) => c.compress(&bytes),
332    }
333}
334
335/// Decompress bytes using the specified codec.
336///
337/// If `codec` is `None`, returns the input unchanged (identity).
338pub fn decompress_bytes(bytes: Bytes, codec: Option<&BoxedCodec>) -> io::Result<Bytes> {
339    match codec {
340        None => Ok(bytes), // identity: zero-copy passthrough
341        Some(c) => c.decompress(&bytes),
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "compression-gzip-stream")]
    #[test]
    fn test_gzip_codec_compress_decompress() {
        let codec = GzipCodec::default();
        assert_eq!(codec.name(), "gzip");

        let original = b"Hello, World! This is a test message.";
        let compressed = codec.compress(original).unwrap();
        assert_ne!(&compressed[..], &original[..]);

        let decompressed = codec.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[cfg(feature = "compression-gzip-stream")]
    #[test]
    fn test_gzip_codec_with_level() {
        let codec = GzipCodec::with_level(9);
        assert_eq!(codec.level, 9);

        let original = b"Hello, World! This is a test message.";
        let compressed = codec.compress(original).unwrap();
        let decompressed = codec.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[cfg(feature = "compression-gzip-stream")]
    #[test]
    fn test_gzip_codec_with_level_clamps() {
        assert_eq!(GzipCodec::with_level(0).level, 0);
        assert_eq!(GzipCodec::with_level(100).level, 9);
    }

    #[cfg(feature = "compression-gzip-stream")]
    #[test]
    fn test_gzip_codec_empty_input() {
        let codec = GzipCodec::default();
        let compressed = codec.compress(b"").unwrap();
        let decompressed = codec.decompress(&compressed).unwrap();
        assert!(decompressed.is_empty());
    }

    #[test]
    fn test_identity_codec() {
        let codec = IdentityCodec;
        assert_eq!(codec.name(), "identity");

        let original = b"Hello, World!";
        let compressed = codec.compress(original).unwrap();
        assert_eq!(&compressed[..], &original[..]);

        let decompressed = codec.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[test]
    fn test_boxed_identity_codec_roundtrip() {
        let codec = BoxedCodec::new(IdentityCodec);
        assert_eq!(codec.name(), "identity");

        let original = Bytes::from_static(b"Hello, World!");
        let compressed = compress_bytes(original.clone(), Some(&codec)).unwrap();
        assert_eq!(compressed, original);

        let decompressed = decompress_bytes(compressed, Some(&codec)).unwrap();
        assert_eq!(decompressed, original);
    }

    #[cfg(feature = "compression-gzip-stream")]
    #[test]
    fn test_boxed_codec() {
        let codec = BoxedCodec::new(GzipCodec::default());
        assert_eq!(codec.name(), "gzip");

        let original = b"Hello, World! This is a test message.";
        let compressed = codec.compress(original).unwrap();
        assert_ne!(&compressed[..], &original[..]);

        let decompressed = codec.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[cfg(feature = "compression-gzip-stream")]
    #[test]
    fn test_compress_decompress_bytes_with_codec() {
        let codec = BoxedCodec::new(GzipCodec::default());
        let original = Bytes::from_static(b"Hello, World! This is a test message.");

        let compressed = compress_bytes(original.clone(), Some(&codec)).unwrap();
        assert_ne!(compressed, original);

        let decompressed = decompress_bytes(compressed, Some(&codec)).unwrap();
        assert_eq!(decompressed, original);
    }

    #[test]
    fn test_compress_decompress_bytes_identity() {
        let original = Bytes::from_static(b"Hello, World!");

        let compressed = compress_bytes(original.clone(), None).unwrap();
        assert_eq!(compressed, original);

        let decompressed = decompress_bytes(compressed, None).unwrap();
        assert_eq!(decompressed, original);
    }

    #[cfg(feature = "compression-gzip-stream")]
    #[test]
    fn test_decompress_invalid_gzip() {
        let codec = BoxedCodec::new(GzipCodec::default());
        let invalid = b"not valid gzip data";
        let result = codec.decompress(invalid);
        assert!(result.is_err());
    }

    #[cfg(feature = "compression-gzip-stream")]
    #[test]
    fn test_boxed_codec_debug() {
        let codec = BoxedCodec::new(GzipCodec::default());
        let debug_str = format!("{:?}", codec);
        assert!(debug_str.contains("BoxedCodec"));
        assert!(debug_str.contains("gzip"));
    }

    #[cfg(feature = "compression-deflate-stream")]
    #[test]
    fn test_deflate_codec_compress_decompress() {
        let codec = DeflateCodec::default();
        assert_eq!(codec.name(), "deflate");

        let original = b"Hello, World! This is a test message for deflate.";
        let compressed = codec.compress(original).unwrap();
        assert_ne!(&compressed[..], &original[..]);

        let decompressed = codec.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[cfg(feature = "compression-deflate-stream")]
    #[test]
    fn test_deflate_codec_with_level() {
        let codec = DeflateCodec::with_level(9);
        assert_eq!(codec.level, 9);
        assert_eq!(DeflateCodec::with_level(42).level, 9);

        let original = b"Hello, World! This is a test message.";
        let compressed = codec.compress(original).unwrap();
        let decompressed = codec.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[cfg(feature = "compression-deflate-stream")]
    #[test]
    fn test_decompress_invalid_deflate() {
        let codec = DeflateCodec::default();
        assert!(codec.decompress(b"not valid zlib data").is_err());
    }

    #[cfg(feature = "compression-br-stream")]
    #[test]
    fn test_brotli_codec_compress_decompress() {
        let codec = BrotliCodec::default();
        assert_eq!(codec.name(), "br");

        let original = b"Hello, World! This is a test message for brotli.";
        let compressed = codec.compress(original).unwrap();
        assert_ne!(&compressed[..], &original[..]);

        let decompressed = codec.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[cfg(feature = "compression-br-stream")]
    #[test]
    fn test_brotli_codec_with_quality() {
        let codec = BrotliCodec::with_quality(11);
        assert_eq!(codec.quality, 11);
        assert_eq!(BrotliCodec::with_quality(50).quality, 11);

        let original = b"Hello, World! This is a test message.";
        let compressed = codec.compress(original).unwrap();
        let decompressed = codec.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[cfg(feature = "compression-zstd-stream")]
    #[test]
    fn test_zstd_codec_compress_decompress() {
        let codec = ZstdCodec::default();
        assert_eq!(codec.name(), "zstd");

        let original = b"Hello, World! This is a test message for zstd.";
        let compressed = codec.compress(original).unwrap();
        assert_ne!(&compressed[..], &original[..]);

        let decompressed = codec.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[cfg(feature = "compression-zstd-stream")]
    #[test]
    fn test_zstd_codec_with_level() {
        let codec = ZstdCodec::with_level(19);
        assert_eq!(codec.level, 19);

        let original = b"Hello, World! This is a test message.";
        let compressed = codec.compress(original).unwrap();
        let decompressed = codec.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[cfg(feature = "compression-zstd-stream")]
    #[test]
    fn test_zstd_codec_with_level_clamps() {
        assert_eq!(ZstdCodec::with_level(0).level, 1);
        assert_eq!(ZstdCodec::with_level(-5).level, 1);
        assert_eq!(ZstdCodec::with_level(100).level, 22);
    }

    #[cfg(feature = "compression-zstd-stream")]
    #[test]
    fn test_decompress_invalid_zstd() {
        let codec = ZstdCodec::default();
        assert!(codec.decompress(b"not valid zstd data").is_err());
    }
}