informalsystems_tonic/codec/
compression.rs

1use super::encode::BUFFER_SIZE;
2use crate::{metadata::MetadataValue, Status};
3use bytes::{Buf, BufMut, BytesMut};
4use flate2::read::{GzDecoder, GzEncoder};
5use std::fmt;
6
7pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
8pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
9
/// Set of compression encodings that a server or channel has switched on.
///
/// The `Default` value has every encoding disabled.
#[derive(Clone, Copy, Debug, Default)]
pub struct EnabledCompressionEncodings {
    /// `true` once `gzip` has been enabled.
    pub(crate) gzip: bool,
}
15
16impl EnabledCompressionEncodings {
17    /// Check if `gzip` compression is enabled.
18    pub fn gzip(self) -> bool {
19        self.gzip
20    }
21
22    /// Enable `gzip` compression.
23    pub fn enable_gzip(&mut self) {
24        self.gzip = true;
25    }
26
27    pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
28        let Self { gzip } = self;
29        if gzip {
30            Some(http::HeaderValue::from_static("gzip,identity"))
31        } else {
32            None
33        }
34    }
35}
36
/// The compression encodings Tonic supports.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// The `gzip` encoding (written as `gzip` in the `grpc-encoding` header).
    Gzip,
}
44
45impl CompressionEncoding {
46    /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
47    pub(crate) fn from_accept_encoding_header(
48        map: &http::HeaderMap,
49        enabled_encodings: EnabledCompressionEncodings,
50    ) -> Option<Self> {
51        let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
52        let header_value_str = header_value.to_str().ok()?;
53
54        let EnabledCompressionEncodings { gzip } = enabled_encodings;
55
56        split_by_comma(header_value_str).find_map(|value| match value {
57            "gzip" if gzip => Some(CompressionEncoding::Gzip),
58            _ => None,
59        })
60    }
61
62    /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
63    pub(crate) fn from_encoding_header(
64        map: &http::HeaderMap,
65        enabled_encodings: EnabledCompressionEncodings,
66    ) -> Result<Option<Self>, Status> {
67        let header_value = if let Some(value) = map.get(ENCODING_HEADER) {
68            value
69        } else {
70            return Ok(None);
71        };
72
73        let header_value_str = if let Ok(value) = header_value.to_str() {
74            value
75        } else {
76            return Ok(None);
77        };
78
79        let EnabledCompressionEncodings { gzip } = enabled_encodings;
80
81        match header_value_str {
82            "gzip" if gzip => Ok(Some(CompressionEncoding::Gzip)),
83            other => {
84                let mut status = Status::unimplemented(format!(
85                    "Content is compressed with `{}` which isn't supported",
86                    other
87                ));
88
89                let header_value = enabled_encodings
90                    .into_accept_encoding_header_value()
91                    .map(MetadataValue::unchecked_from_header_value)
92                    .unwrap_or_else(|| MetadataValue::from_static("identity"));
93                status
94                    .metadata_mut()
95                    .insert(ACCEPT_ENCODING_HEADER, header_value);
96
97                Err(status)
98            }
99        }
100    }
101
102    pub(crate) fn into_header_value(self) -> http::HeaderValue {
103        match self {
104            CompressionEncoding::Gzip => http::HeaderValue::from_static("gzip"),
105        }
106    }
107}
108
109impl fmt::Display for CompressionEncoding {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        match self {
112            CompressionEncoding::Gzip => write!(f, "gzip"),
113        }
114    }
115}
116
/// Split `s` on `,` and yield every piece with surrounding whitespace removed.
///
/// Empty pieces are kept, so an empty input yields a single empty string.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    // Each piece is trimmed individually, which also takes care of any
    // whitespace at the very start or end of `s`.
    s.split(',').map(str::trim)
}
120
121/// Compress `len` bytes from `decompressed_buf` into `out_buf`.
122pub(crate) fn compress(
123    encoding: CompressionEncoding,
124    decompressed_buf: &mut BytesMut,
125    out_buf: &mut BytesMut,
126    len: usize,
127) -> Result<(), std::io::Error> {
128    let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
129    out_buf.reserve(capacity);
130
131    match encoding {
132        CompressionEncoding::Gzip => {
133            let mut gzip_encoder = GzEncoder::new(
134                &decompressed_buf[0..len],
135                // FIXME: support customizing the compression level
136                flate2::Compression::new(6),
137            );
138            let mut out_writer = out_buf.writer();
139
140            std::io::copy(&mut gzip_encoder, &mut out_writer)?;
141        }
142    }
143
144    decompressed_buf.advance(len);
145
146    Ok(())
147}
148
149/// Decompress `len` bytes from `compressed_buf` into `out_buf`.
150pub(crate) fn decompress(
151    encoding: CompressionEncoding,
152    compressed_buf: &mut BytesMut,
153    out_buf: &mut BytesMut,
154    len: usize,
155) -> Result<(), std::io::Error> {
156    let estimate_decompressed_len = len * 2;
157    let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
158    out_buf.reserve(capacity);
159
160    match encoding {
161        CompressionEncoding::Gzip => {
162            let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
163            let mut out_writer = out_buf.writer();
164
165            std::io::copy(&mut gzip_decoder, &mut out_writer)?;
166        }
167    }
168
169    compressed_buf.advance(len);
170
171    Ok(())
172}
173
/// Per-message override of the compression configured on a stream.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum SingleMessageCompressionOverride {
    /// Use whatever compression is already configured: if the stream is
    /// compressed, this message is compressed as well.
    ///
    /// This is the default.
    Inherit,
    /// Send this message uncompressed, even if compression is enabled on the
    /// stream.
    Disable,
}

impl Default for SingleMessageCompressionOverride {
    /// Returns [`SingleMessageCompressionOverride::Inherit`].
    fn default() -> Self {
        SingleMessageCompressionOverride::Inherit
    }
}