1use std::{io::Result as IoResult, str::FromStr};
2
3use http::HeaderMap;
4
5use crate::{Code, Metadata, Status};
6
/// A compression algorithm that can be applied to gRPC message payloads.
///
/// Each variant only exists when the cargo feature of the same (lowercase)
/// name is enabled: `gzip`, `deflate`, `brotli` or `zstd`. With none of those
/// features enabled the enum has no variants at all.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CompressionEncoding {
    /// gzip compression; header token `gzip`.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    GZIP,
    /// deflate compression; header token `deflate`.
    #[cfg(feature = "deflate")]
    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
    DEFLATE,
    /// Brotli compression; header token `br`.
    #[cfg(feature = "brotli")]
    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
    BROTLI,
    /// Zstandard compression; header token `zstd`.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    ZSTD,
}
27
28impl FromStr for CompressionEncoding {
29 type Err = ();
30
31 #[inline]
32 fn from_str(s: &str) -> Result<Self, Self::Err> {
33 match s {
34 #[cfg(feature = "gzip")]
35 "gzip" => Ok(CompressionEncoding::GZIP),
36 #[cfg(feature = "deflate")]
37 "deflate" => Ok(CompressionEncoding::DEFLATE),
38 #[cfg(feature = "brotli")]
39 "br" => Ok(CompressionEncoding::BROTLI),
40 #[cfg(feature = "zstd")]
41 "zstd" => Ok(CompressionEncoding::ZSTD),
42 _ => Err(()),
43 }
44 }
45}
46
47impl CompressionEncoding {
48 #[allow(unreachable_patterns)]
50 pub fn as_str(&self) -> &'static str {
51 match self {
52 #[cfg(feature = "gzip")]
53 CompressionEncoding::GZIP => "gzip",
54 #[cfg(feature = "deflate")]
55 CompressionEncoding::DEFLATE => "deflate",
56 #[cfg(feature = "brotli")]
57 CompressionEncoding::BROTLI => "br",
58 #[cfg(feature = "zstd")]
59 CompressionEncoding::ZSTD => "zstd",
60 _ => unreachable!(),
61 }
62 }
63
64 #[allow(
65 unreachable_code,
66 unused_imports,
67 unused_mut,
68 unused_variables,
69 unreachable_patterns
70 )]
71 pub(crate) async fn encode(&self, data: &[u8]) -> IoResult<Vec<u8>> {
72 use tokio::io::AsyncReadExt;
73
74 let mut buf = Vec::new();
75
76 match self {
77 #[cfg(feature = "gzip")]
78 CompressionEncoding::GZIP => {
79 async_compression::tokio::bufread::GzipEncoder::new(data)
80 .read_to_end(&mut buf)
81 .await?;
82 }
83 #[cfg(feature = "deflate")]
84 CompressionEncoding::DEFLATE => {
85 async_compression::tokio::bufread::DeflateEncoder::new(data)
86 .read_to_end(&mut buf)
87 .await?;
88 }
89 #[cfg(feature = "brotli")]
90 CompressionEncoding::BROTLI => {
91 async_compression::tokio::bufread::BrotliEncoder::new(data)
92 .read_to_end(&mut buf)
93 .await?;
94 }
95 #[cfg(feature = "zstd")]
96 CompressionEncoding::ZSTD => {
97 async_compression::tokio::bufread::ZstdEncoder::new(data)
98 .read_to_end(&mut buf)
99 .await?;
100 }
101 _ => unreachable!(),
102 }
103
104 Ok(buf)
105 }
106
107 #[allow(
108 unreachable_code,
109 unused_imports,
110 unused_mut,
111 unused_variables,
112 unreachable_patterns
113 )]
114 pub(crate) async fn decode(&self, data: &[u8]) -> IoResult<Vec<u8>> {
115 use tokio::io::AsyncReadExt;
116
117 let mut buf = Vec::new();
118
119 match self {
120 #[cfg(feature = "gzip")]
121 CompressionEncoding::GZIP => {
122 async_compression::tokio::bufread::GzipDecoder::new(data)
123 .read_to_end(&mut buf)
124 .await?;
125 }
126 #[cfg(feature = "deflate")]
127 CompressionEncoding::DEFLATE => {
128 async_compression::tokio::bufread::DeflateDecoder::new(data)
129 .read_to_end(&mut buf)
130 .await?;
131 }
132 #[cfg(feature = "brotli")]
133 CompressionEncoding::BROTLI => {
134 async_compression::tokio::bufread::BrotliDecoder::new(data)
135 .read_to_end(&mut buf)
136 .await?;
137 }
138 #[cfg(feature = "zstd")]
139 CompressionEncoding::ZSTD => {
140 async_compression::tokio::bufread::ZstdDecoder::new(data)
141 .read_to_end(&mut buf)
142 .await?;
143 }
144 _ => unreachable!(),
145 }
146
147 Ok(buf)
148 }
149}
150
151fn unimplemented(accept_compressed: &[CompressionEncoding]) -> Status {
152 let mut md = Metadata::new();
153 let mut accept_encoding = String::new();
154 let mut iter = accept_compressed.iter();
155 if let Some(encoding) = iter.next() {
156 accept_encoding.push_str(encoding.as_str());
157 }
158 for encoding in iter {
159 accept_encoding.push_str(", ");
160 accept_encoding.push_str(encoding.as_str());
161 }
162 md.append("grpc-accept-encoding", accept_encoding);
163 Status::new(Code::Unimplemented)
164 .with_metadata(md)
165 .with_message("unsupported encoding")
166}
167
168#[allow(clippy::result_large_err)]
169pub(crate) fn get_incoming_encodings(
170 headers: &HeaderMap,
171 accept_compressed: &[CompressionEncoding],
172) -> Result<Option<CompressionEncoding>, Status> {
173 let Some(value) = headers.get("grpc-encoding") else {
174 return Ok(None);
175 };
176 let Some(encoding) = value
177 .to_str()
178 .ok()
179 .and_then(|value| value.parse::<CompressionEncoding>().ok())
180 else {
181 return Err(unimplemented(accept_compressed));
182 };
183 if !accept_compressed.contains(&encoding) {
184 return Err(unimplemented(accept_compressed));
185 }
186 Ok(Some(encoding))
187}