volo_grpc/codec/
compression.rs

//! This code is copied from `tonic/src/codec/compression.rs` and may have been modified for volo.
2
3use std::io;
4
5#[cfg(feature = "compress")]
6use bytes::BufMut;
7use bytes::{Buf, BytesMut};
8#[cfg(feature = "compress")]
9pub use flate2::Compression as Level;
10#[cfg(feature = "gzip")]
11use flate2::bufread::{GzDecoder, GzEncoder};
12#[cfg(feature = "zlib")]
13use flate2::bufread::{ZlibDecoder, ZlibEncoder};
14use http::HeaderValue;
15
16use super::BUFFER_SIZE;
17#[cfg(feature = "compress")]
18use crate::Status;
19
/// Name of the header that tells the receiver which algorithm a message was compressed with.
pub const ENCODING_HEADER: &str = "grpc-encoding";
/// Name of the header in which a peer lists the compression algorithms it accepts.
pub const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
/// Compression level (6) used by the `*Config` defaults and by `CompressionEncoding::level`
/// when an encoding carries no explicit configuration.
#[cfg(feature = "compress")]
const DEFAULT_LEVEL: Level = Level::new(6);
24
/// The compression encodings volo supports.
///
/// Each compressing variant optionally carries its configuration; `None` means no explicit
/// configuration was given. Equality compares only the algorithm, never the configuration.
#[derive(Clone, Copy, Debug)]
pub enum CompressionEncoding {
    /// No compression; named `identity` on the wire.
    Identity,
    /// gzip compression; named `gzip` on the wire.
    #[cfg(feature = "gzip")]
    Gzip(Option<GzipConfig>),
    /// zlib compression; named `zlib` on the wire.
    #[cfg(feature = "zlib")]
    Zlib(Option<ZlibConfig>),
    /// zstd compression; named `zstd` on the wire.
    #[cfg(feature = "zstd")]
    Zstd(Option<ZstdConfig>),
}
36
37impl PartialEq for CompressionEncoding {
38    fn eq(&self, other: &Self) -> bool {
39        match (self, other) {
40            #[cfg(feature = "gzip")]
41            (Self::Gzip(_), Self::Gzip(_)) => true,
42            #[cfg(feature = "zlib")]
43            (Self::Zlib(_), Self::Zlib(_)) => true,
44            (Self::Identity, Self::Identity) => true,
45            #[cfg(feature = "zstd")]
46            (Self::Zstd(_), Self::Zstd(_)) => true,
47            #[cfg(feature = "compress")]
48            _ => false,
49        }
50    }
51}
52
/// Tuning options for gzip compression.
#[cfg(feature = "gzip")]
#[derive(Clone, Copy, Debug)]
pub struct GzipConfig {
    /// flate2 compression level handed to the gzip encoder.
    pub level: Level,
}

/// Uses the module-wide default level (6).
#[cfg(feature = "gzip")]
impl Default for GzipConfig {
    fn default() -> Self {
        GzipConfig { level: DEFAULT_LEVEL }
    }
}
67
/// Tuning options for zlib compression.
#[cfg(feature = "zlib")]
#[derive(Clone, Copy, Debug)]
pub struct ZlibConfig {
    /// flate2 compression level handed to the zlib encoder.
    pub level: Level,
}

/// Uses the module-wide default level (6).
#[cfg(feature = "zlib")]
impl Default for ZlibConfig {
    fn default() -> Self {
        ZlibConfig { level: DEFAULT_LEVEL }
    }
}
82
/// Tuning options for zstd compression.
#[cfg(feature = "zstd")]
#[derive(Clone, Copy, Debug)]
pub struct ZstdConfig {
    /// Compression level, stored as a flate2 [`Level`]. When compressing, its numeric value
    /// is used as the zstd level, and 0 selects zstd's own default level.
    pub level: Level,
}

/// Uses the module-wide default level (6).
#[cfg(feature = "zstd")]
impl Default for ZstdConfig {
    fn default() -> Self {
        ZstdConfig { level: DEFAULT_LEVEL }
    }
}
97
98/// compose multiple compression encodings to a [HeaderValue]
99pub fn compose_encodings(encodings: &[CompressionEncoding]) -> HeaderValue {
100    let encodings = encodings
101        .iter()
102        .map(|item| match item {
103            // TODO: gzip-6 @https://grpc.github.io/grpc/core/md_doc_compression.html#autotoc_md59
104            #[cfg(feature = "gzip")]
105            CompressionEncoding::Gzip(_) => "gzip",
106            #[cfg(feature = "zlib")]
107            CompressionEncoding::Zlib(_) => "zlib",
108            #[cfg(feature = "zstd")]
109            CompressionEncoding::Zstd(_) => "zstd",
110            CompressionEncoding::Identity => "identity",
111        })
112        .collect::<Vec<&'static str>>();
113    // encodings.push("identity");
114
115    HeaderValue::from_str(encodings.join(",").as_str()).unwrap()
116}
117
/// Reports whether the algorithm of `encoding` appears anywhere in `encodings`.
///
/// Matching goes through `CompressionEncoding`'s `PartialEq`, which ignores configuration.
#[cfg(feature = "compress")]
fn is_enabled(encoding: CompressionEncoding, encodings: &[CompressionEncoding]) -> bool {
    encodings.iter().any(|candidate| encoding == *candidate)
}
122
123impl CompressionEncoding {
124    /// make the compression encoding into a [HeaderValue]
125    pub fn into_header_value(self) -> HeaderValue {
126        match self {
127            #[cfg(feature = "gzip")]
128            CompressionEncoding::Gzip(_) => HeaderValue::from_static("gzip"),
129            #[cfg(feature = "zlib")]
130            CompressionEncoding::Zlib(_) => HeaderValue::from_static("zlib"),
131            #[cfg(feature = "zstd")]
132            CompressionEncoding::Zstd(_) => HeaderValue::from_static("zstd"),
133            CompressionEncoding::Identity => HeaderValue::from_static("identity"),
134        }
135    }
136
137    /// make the compression encodings into a [HeaderValue],and the encodings uses a `,` as
138    /// separator
139    pub fn into_accept_encoding_header_value(
140        self,
141        encodings: &[CompressionEncoding],
142    ) -> Option<HeaderValue> {
143        if self.is_enabled() {
144            Some(compose_encodings(encodings))
145        } else {
146            None
147        }
148    }
149
150    /// Based on the `grpc-accept-encoding` header, adaptive picking an encoding to use.
151    #[cfg(feature = "compress")]
152    pub fn from_accept_encoding_header(
153        headers: &http::HeaderMap,
154        config: &Option<Vec<Self>>,
155    ) -> Option<Self> {
156        if let Some(available_encodings) = config {
157            let header_value = headers.get(ACCEPT_ENCODING_HEADER)?;
158            let header_value_str = header_value.to_str().ok()?;
159
160            header_value_str
161                .split(',')
162                .map(|s| s.trim())
163                .find_map(|encoding| match encoding {
164                    #[cfg(feature = "gzip")]
165                    "gzip" => available_encodings.iter().find_map(|item| {
166                        if item.is_gzip_enabled() {
167                            Some(*item)
168                        } else {
169                            None
170                        }
171                    }),
172                    #[cfg(feature = "zlib")]
173                    "zlib" => available_encodings.iter().find_map(|item| {
174                        if item.is_zlib_enabled() {
175                            Some(*item)
176                        } else {
177                            None
178                        }
179                    }),
180                    #[cfg(feature = "zstd")]
181                    "zstd" => available_encodings.iter().find_map(|item| {
182                        if item.is_zstd_enabled() {
183                            Some(*item)
184                        } else {
185                            None
186                        }
187                    }),
188                    _ => None,
189                })
190        } else {
191            None
192        }
193    }
194
195    /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
196    #[allow(clippy::result_large_err)]
197    #[cfg(feature = "compress")]
198    pub fn from_encoding_header(
199        headers: &http::HeaderMap,
200        config: &Option<Vec<Self>>,
201    ) -> Result<Option<Self>, Status> {
202        if let Some(encodings) = config {
203            let header_value = if let Some(header_value) = headers.get(ENCODING_HEADER) {
204                header_value
205            } else {
206                return Ok(None);
207            };
208
209            match header_value.to_str()? {
210                #[cfg(feature = "gzip")]
211                "gzip" if is_enabled(Self::Gzip(None), encodings) => Ok(Some(Self::Gzip(None))),
212                #[cfg(feature = "zlib")]
213                "zlib" if is_enabled(Self::Zlib(None), encodings) => Ok(Some(Self::Zlib(None))),
214                #[cfg(feature = "zstd")]
215                "zstd" if is_enabled(Self::Zstd(None), encodings) => Ok(Some(Self::Zstd(None))),
216                "identity" => Ok(None),
217                other => {
218                    let status = Status::unimplemented(format!(
219                        "Content is compressed with `{other}` which isn't supported"
220                    ));
221                    Err(status)
222                }
223            }
224        } else {
225            Ok(None)
226        }
227    }
228
229    /// please use it only for Compression type is insignificant, otherwise you will have a
230    /// duplicate pattern-matching problem
231    #[cfg(feature = "compress")]
232    pub fn level(self) -> Level {
233        match self {
234            #[cfg(feature = "gzip")]
235            CompressionEncoding::Gzip(Some(config)) => config.level,
236            #[cfg(feature = "zlib")]
237            CompressionEncoding::Zlib(Some(config)) => config.level,
238            #[cfg(feature = "zstd")]
239            CompressionEncoding::Zstd(Some(config)) => config.level,
240            _ => DEFAULT_LEVEL,
241        }
242    }
243
244    #[cfg(feature = "gzip")]
245    const fn is_gzip_enabled(&self) -> bool {
246        matches!(self, CompressionEncoding::Gzip(_))
247    }
248
249    #[cfg(feature = "zlib")]
250    const fn is_zlib_enabled(&self) -> bool {
251        matches!(self, CompressionEncoding::Zlib(_))
252    }
253
254    #[cfg(feature = "zstd")]
255    const fn is_zstd_enabled(&self) -> bool {
256        matches!(self, CompressionEncoding::Zstd(_))
257    }
258
259    const fn is_enabled(&self) -> bool {
260        #[allow(unreachable_patterns)]
261        match self {
262            #[cfg(feature = "gzip")]
263            CompressionEncoding::Gzip(_) => true,
264            #[cfg(feature = "zlib")]
265            CompressionEncoding::Zlib(_) => true,
266            #[cfg(feature = "zstd")]
267            CompressionEncoding::Zstd(_) => true,
268            _ => false,
269        }
270    }
271}
272
273/// Compress `len` bytes from `src_buf` into `dest_buf`.
274pub(crate) fn compress(
275    encoding: CompressionEncoding,
276    src_buf: &mut BytesMut,
277    dest_buf: &mut BytesMut,
278) -> Result<(), io::Error> {
279    let len = src_buf.len();
280    let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
281
282    dest_buf.reserve(capacity);
283
284    match encoding {
285        #[cfg(feature = "gzip")]
286        CompressionEncoding::Gzip(Some(config)) => {
287            let mut gz_encoder = GzEncoder::new(&src_buf[0..len], config.level);
288            io::copy(&mut gz_encoder, &mut dest_buf.writer())?;
289        }
290        #[cfg(feature = "zlib")]
291        CompressionEncoding::Zlib(Some(config)) => {
292            let mut zlib_encoder = ZlibEncoder::new(&src_buf[0..len], config.level);
293            io::copy(&mut zlib_encoder, &mut dest_buf.writer())?;
294        }
295        #[cfg(feature = "zstd")]
296        CompressionEncoding::Zstd(Some(config)) => {
297            let level = config.level.level();
298            let zstd_level = if level == 0 {
299                zstd::DEFAULT_COMPRESSION_LEVEL
300            } else {
301                level as i32
302            };
303            let mut zstd_encoder = zstd::Encoder::new(dest_buf.writer(), zstd_level)?;
304            io::copy(&mut &src_buf[0..len], &mut zstd_encoder)?;
305            zstd_encoder.finish()?;
306        }
307        _ => {}
308    };
309
310    src_buf.advance(len);
311    Ok(())
312}
313
314/// Decompress `len` bytes from `src_buf` into `dest_buf`.
315pub(crate) fn decompress(
316    encoding: CompressionEncoding,
317    src_buf: &mut BytesMut,
318    dest_buf: &mut BytesMut,
319) -> Result<(), io::Error> {
320    let len = src_buf.len();
321    let estimate_decompressed_len = len * 2;
322    let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
323
324    dest_buf.reserve(capacity);
325
326    match encoding {
327        #[cfg(feature = "gzip")]
328        CompressionEncoding::Gzip(_) => {
329            let mut gz_decoder = GzDecoder::new(&src_buf[0..len]);
330            io::copy(&mut gz_decoder, &mut dest_buf.writer())?;
331        }
332        #[cfg(feature = "zlib")]
333        CompressionEncoding::Zlib(_) => {
334            let mut zlib_decoder = ZlibDecoder::new(&src_buf[0..len]);
335            io::copy(&mut zlib_decoder, &mut dest_buf.writer())?;
336        }
337        #[cfg(feature = "zstd")]
338        CompressionEncoding::Zstd(_) => {
339            let mut zstd_decoder = zstd::Decoder::new(&src_buf[0..len])?;
340            io::copy(&mut zstd_decoder, &mut dest_buf.writer())?;
341        }
342        _ => {}
343    };
344
345    src_buf.advance(len);
346    Ok(())
347}
348
#[cfg(test)]
mod tests {
    use bytes::BytesMut;

    #[cfg(feature = "gzip")]
    use crate::codec::compression::GzipConfig;
    #[cfg(feature = "compress")]
    use crate::codec::compression::Level;
    #[cfg(feature = "zlib")]
    use crate::codec::compression::ZlibConfig;
    #[cfg(feature = "zstd")]
    use crate::codec::compression::ZstdConfig;
    use crate::codec::{
        BUFFER_SIZE,
        compression::{CompressionEncoding, compress, decompress},
    };

    /// Round-trips a small payload through every compiled-in compression algorithm.
    #[test]
    fn test_consistency_for_compression() {
        let test_data = &b"test compression"[..];

        // Only real compression algorithms are exercised here; `Identity` is not one.
        let encodings: &[CompressionEncoding] = &[
            #[cfg(feature = "gzip")]
            CompressionEncoding::Gzip(Some(GzipConfig {
                level: Level::fast(),
            })),
            #[cfg(feature = "zlib")]
            CompressionEncoding::Zlib(Some(ZlibConfig {
                level: Level::fast(),
            })),
            #[cfg(feature = "zstd")]
            CompressionEncoding::Zstd(Some(ZstdConfig {
                level: Level::new(3),
            })),
        ];

        for &encoding in encodings {
            // Fresh buffers per encoding: `compress` drains `src`, and a reused output
            // buffer would let this iteration pass on a previous iteration's data.
            let mut src = BytesMut::with_capacity(BUFFER_SIZE);
            src.extend_from_slice(test_data);
            let mut compressed = BytesMut::new();
            let mut decompressed = BytesMut::with_capacity(BUFFER_SIZE);

            compress(encoding, &mut src, &mut compressed).expect("compress failed:");
            assert!(src.is_empty(), "compress must consume its input ({encoding:?})");
            assert!(!compressed.is_empty(), "compressed output is empty ({encoding:?})");

            decompress(encoding, &mut compressed, &mut decompressed)
                .expect("decompress failed:");
            assert!(compressed.is_empty(), "decompress must consume its input ({encoding:?})");
            assert_eq!(
                test_data,
                decompressed.as_ref(),
                "round trip mismatch for {encoding:?}"
            );
        }
    }
}