poem_grpc/
compression.rs

1use std::{io::Result as IoResult, str::FromStr};
2
3use http::HeaderMap;
4
5use crate::{Code, Metadata, Status};
6
/// A compression algorithm that can be applied to gRPC message payloads.
///
/// Every variant is gated behind the cargo feature of the same name, so a
/// build with no compression features enabled gets an enum with no variants.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CompressionEncoding {
    /// gzip (header token `gzip`).
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    GZIP,
    /// deflate (header token `deflate`).
    #[cfg(feature = "deflate")]
    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
    DEFLATE,
    /// brotli (header token `br`).
    #[cfg(feature = "brotli")]
    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
    BROTLI,
    /// zstd (header token `zstd`).
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    ZSTD,
}
27
28impl FromStr for CompressionEncoding {
29    type Err = ();
30
31    #[inline]
32    fn from_str(s: &str) -> Result<Self, Self::Err> {
33        match s {
34            #[cfg(feature = "gzip")]
35            "gzip" => Ok(CompressionEncoding::GZIP),
36            #[cfg(feature = "deflate")]
37            "deflate" => Ok(CompressionEncoding::DEFLATE),
38            #[cfg(feature = "brotli")]
39            "br" => Ok(CompressionEncoding::BROTLI),
40            #[cfg(feature = "zstd")]
41            "zstd" => Ok(CompressionEncoding::ZSTD),
42            _ => Err(()),
43        }
44    }
45}
46
47impl CompressionEncoding {
48    /// Returns the encoding name.
49    #[allow(unreachable_patterns)]
50    pub fn as_str(&self) -> &'static str {
51        match self {
52            #[cfg(feature = "gzip")]
53            CompressionEncoding::GZIP => "gzip",
54            #[cfg(feature = "deflate")]
55            CompressionEncoding::DEFLATE => "deflate",
56            #[cfg(feature = "brotli")]
57            CompressionEncoding::BROTLI => "br",
58            #[cfg(feature = "zstd")]
59            CompressionEncoding::ZSTD => "zstd",
60            _ => unreachable!(),
61        }
62    }
63
64    #[allow(
65        unreachable_code,
66        unused_imports,
67        unused_mut,
68        unused_variables,
69        unreachable_patterns
70    )]
71    pub(crate) async fn encode(&self, data: &[u8]) -> IoResult<Vec<u8>> {
72        use tokio::io::AsyncReadExt;
73
74        let mut buf = Vec::new();
75
76        match self {
77            #[cfg(feature = "gzip")]
78            CompressionEncoding::GZIP => {
79                async_compression::tokio::bufread::GzipEncoder::new(data)
80                    .read_to_end(&mut buf)
81                    .await?;
82            }
83            #[cfg(feature = "deflate")]
84            CompressionEncoding::DEFLATE => {
85                async_compression::tokio::bufread::DeflateEncoder::new(data)
86                    .read_to_end(&mut buf)
87                    .await?;
88            }
89            #[cfg(feature = "brotli")]
90            CompressionEncoding::BROTLI => {
91                async_compression::tokio::bufread::BrotliEncoder::new(data)
92                    .read_to_end(&mut buf)
93                    .await?;
94            }
95            #[cfg(feature = "zstd")]
96            CompressionEncoding::ZSTD => {
97                async_compression::tokio::bufread::ZstdEncoder::new(data)
98                    .read_to_end(&mut buf)
99                    .await?;
100            }
101            _ => unreachable!(),
102        }
103
104        Ok(buf)
105    }
106
107    #[allow(
108        unreachable_code,
109        unused_imports,
110        unused_mut,
111        unused_variables,
112        unreachable_patterns
113    )]
114    pub(crate) async fn decode(&self, data: &[u8]) -> IoResult<Vec<u8>> {
115        use tokio::io::AsyncReadExt;
116
117        let mut buf = Vec::new();
118
119        match self {
120            #[cfg(feature = "gzip")]
121            CompressionEncoding::GZIP => {
122                async_compression::tokio::bufread::GzipDecoder::new(data)
123                    .read_to_end(&mut buf)
124                    .await?;
125            }
126            #[cfg(feature = "deflate")]
127            CompressionEncoding::DEFLATE => {
128                async_compression::tokio::bufread::DeflateDecoder::new(data)
129                    .read_to_end(&mut buf)
130                    .await?;
131            }
132            #[cfg(feature = "brotli")]
133            CompressionEncoding::BROTLI => {
134                async_compression::tokio::bufread::BrotliDecoder::new(data)
135                    .read_to_end(&mut buf)
136                    .await?;
137            }
138            #[cfg(feature = "zstd")]
139            CompressionEncoding::ZSTD => {
140                async_compression::tokio::bufread::ZstdDecoder::new(data)
141                    .read_to_end(&mut buf)
142                    .await?;
143            }
144            _ => unreachable!(),
145        }
146
147        Ok(buf)
148    }
149}
150
151fn unimplemented(accept_compressed: &[CompressionEncoding]) -> Status {
152    let mut md = Metadata::new();
153    let mut accept_encoding = String::new();
154    let mut iter = accept_compressed.iter();
155    if let Some(encoding) = iter.next() {
156        accept_encoding.push_str(encoding.as_str());
157    }
158    for encoding in iter {
159        accept_encoding.push_str(", ");
160        accept_encoding.push_str(encoding.as_str());
161    }
162    md.append("grpc-accept-encoding", accept_encoding);
163    Status::new(Code::Unimplemented)
164        .with_metadata(md)
165        .with_message("unsupported encoding")
166}
167
168#[allow(clippy::result_large_err)]
169pub(crate) fn get_incoming_encodings(
170    headers: &HeaderMap,
171    accept_compressed: &[CompressionEncoding],
172) -> Result<Option<CompressionEncoding>, Status> {
173    let Some(value) = headers.get("grpc-encoding") else {
174        return Ok(None);
175    };
176    let Some(encoding) = value
177        .to_str()
178        .ok()
179        .and_then(|value| value.parse::<CompressionEncoding>().ok())
180    else {
181        return Err(unimplemented(accept_compressed));
182    };
183    if !accept_compressed.contains(&encoding) {
184        return Err(unimplemented(accept_compressed));
185    }
186    Ok(Some(encoding))
187}