rocket_async_compression/
lib.rs1#[macro_use]
33extern crate log;
34
35mod fairing;
36mod responder;
37
38pub use self::{
39 fairing::{CachedCompression, Compression},
40 responder::Compress,
41};
42
43pub use async_compression::Level;
44use fairing::CachedEncoding;
45use rocket::{
46 http::{hyper::header::CONTENT_ENCODING, MediaType},
47 response::Body,
48 Request, Response,
49};
50
/// An HTTP coding token as it appears in `Content-Encoding`-style headers.
///
/// Each unit variant corresponds to one well-known token; any token that is
/// not one of those is carried verbatim in [`Encoding::EncodingExt`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Encoding {
    /// The `chunked` token.
    Chunked,
    /// The `br` (Brotli) token.
    Brotli,
    /// The `gzip` token.
    Gzip,
    /// The `deflate` token.
    Deflate,
    /// The `compress` token.
    Compress,
    /// The `identity` token.
    Identity,
    /// The `trailers` token.
    Trailers,
    /// Any other token, stored exactly as given.
    EncodingExt(String),
}
69
70impl std::fmt::Display for Encoding {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 f.write_str(match *self {
73 Encoding::Chunked => "chunked",
74 Encoding::Brotli => "br",
75 Encoding::Gzip => "gzip",
76 Encoding::Deflate => "deflate",
77 Encoding::Compress => "compress",
78 Encoding::Identity => "identity",
79 Encoding::Trailers => "trailers",
80 Encoding::EncodingExt(ref s) => s.as_ref(),
81 })
82 }
83}
84
85impl std::str::FromStr for Encoding {
86 type Err = std::convert::Infallible;
87
88 fn from_str(s: &str) -> Result<Encoding, std::convert::Infallible> {
89 match s {
90 "chunked" => Ok(Encoding::Chunked),
91 "br" => Ok(Encoding::Brotli),
92 "deflate" => Ok(Encoding::Deflate),
93 "gzip" => Ok(Encoding::Gzip),
94 "compress" => Ok(Encoding::Compress),
95 "identity" => Ok(Encoding::Identity),
96 "trailers" => Ok(Encoding::Trailers),
97 _ => Ok(Encoding::EncodingExt(s.to_owned())),
98 }
99 }
100}
101
102struct CompressionUtils;
103
104impl CompressionUtils {
105 fn already_encoded(response: &Response<'_>) -> bool {
106 response.headers().get("Content-Encoding").next().is_some()
107 }
108
109 fn set_body_and_encoding<'r, B: rocket::tokio::io::AsyncRead + Send + 'r>(
110 response: &'_ mut Response<'r>,
111 body: B,
112 encoding: Encoding,
113 ) {
114 response.set_header(::rocket::http::Header::new(
115 CONTENT_ENCODING.as_str(),
116 format!("{}", encoding),
117 ));
118 response.set_streamed_body(body);
119 }
120
121 fn skip_encoding(
122 content_type: &Option<rocket::http::ContentType>,
123 exclusions: &[MediaType],
124 ) -> bool {
125 match content_type {
126 Some(content_type) => exclusions.iter().any(|exc_media_type| {
127 if exc_media_type.sub() == "*" {
128 *exc_media_type.top() == *content_type.top()
129 } else {
130 *exc_media_type == *content_type.media_type()
131 }
132 }),
133 None => false,
134 }
135 }
136
137 fn accepted_algorithms(request: &Request<'_>) -> (bool, bool) {
139 request
140 .headers()
141 .get("Accept-Encoding")
142 .flat_map(|accept| accept.split(','))
143 .map(|accept| accept.trim())
144 .fold((false, false), |(accepts_gzip, accepts_br), encoding| {
145 (
146 accepts_gzip || encoding == "gzip",
147 accepts_br || encoding == "br",
148 )
149 })
150 }
151
152 async fn compress_body<'r>(
153 body: Body<'r>,
154 encoding: CachedEncoding,
155 level: async_compression::Level,
156 ) -> std::io::Result<Vec<u8>> {
157 match encoding {
158 CachedEncoding::Brotli => {
159 let level = match level {
164 async_compression::Level::Default => async_compression::Level::Precise(4),
165 other => other,
166 };
167
168 let mut compressor = async_compression::tokio::bufread::BrotliEncoder::with_quality(
169 rocket::tokio::io::BufReader::new(body),
170 level,
171 );
172 let mut out = Vec::new();
173 rocket::tokio::io::copy(&mut compressor, &mut out).await?;
174 Ok(out)
175 }
176 CachedEncoding::Gzip => {
177 let mut compressor = async_compression::tokio::bufread::GzipEncoder::with_quality(
178 rocket::tokio::io::BufReader::new(body),
179 level,
180 );
181 let mut out = Vec::new();
182 rocket::tokio::io::copy(&mut compressor, &mut out).await?;
183 Ok(out)
184 }
185 }
186 }
187
188 fn compress_response<'r>(
189 request: &Request<'_>,
190 response: &'_ mut Response<'r>,
191 exclusions: &[MediaType],
192 level: async_compression::Level,
193 ) {
194 if CompressionUtils::already_encoded(response) {
195 return;
196 }
197
198 let content_type = response.content_type();
199
200 if CompressionUtils::skip_encoding(&content_type, exclusions) {
201 return;
202 }
203
204 let (accepts_gzip, accepts_br) = Self::accepted_algorithms(request);
205
206 if !accepts_gzip && !accepts_br {
207 return;
208 }
209
210 let body = response.body_mut().take();
211
212 if accepts_br {
214 let compressor = async_compression::tokio::bufread::BrotliEncoder::with_quality(
215 rocket::tokio::io::BufReader::new(body),
216 level,
217 );
218
219 CompressionUtils::set_body_and_encoding(response, compressor, Encoding::Brotli);
220 } else if accepts_gzip {
221 let compressor = async_compression::tokio::bufread::GzipEncoder::with_quality(
222 rocket::tokio::io::BufReader::new(body),
223 level,
224 );
225
226 CompressionUtils::set_body_and_encoding(response, compressor, Encoding::Gzip);
227 }
228 }
229}