use bytes::Bytes;
use core::fmt;
use flate2::{
    read::{DeflateDecoder, DeflateEncoder, GzDecoder, GzEncoder, ZlibDecoder, ZlibEncoder},
    Compression,
};
use http::{
    header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE},
    HeaderValue, Request, Response, StatusCode,
};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{body::Incoming, service::Service};
use std::{fmt::Debug, future::Future, pin::Pin};
use std::{io::prelude::*, str::FromStr};
23
/// Module-wide result type; every fallible helper here fails with [`HyperContentEncodingError`].
type Result<T> = std::result::Result<T, HyperContentEncodingError>;

/// Error reported by this module. It only carries a human-readable message.
#[derive(Debug, Clone)]
pub struct HyperContentEncodingError {
    inner: String,
}

impl HyperContentEncodingError {
    /// Builds an error whose message is `inner`.
    pub fn new(inner: String) -> Self {
        Self { inner }
    }
}

impl std::error::Error for HyperContentEncodingError {}

impl fmt::Display for HyperContentEncodingError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.inner)
    }
}

impl From<String> for HyperContentEncodingError {
    fn from(message: String) -> Self {
        Self::new(message)
    }
}

/// Converts any `Debug` value (typically a foreign error) into a [`HyperContentEncodingError`]
/// whose message is the value's `{:?}` rendering; intended for use with `map_err`.
fn convert_err<E: Debug>(e: E) -> HyperContentEncodingError {
    let rendered = format!("{:?}", e);
    HyperContentEncodingError::new(rendered)
}
58
59trait Decoder<T>: Read
60where
61 T: Read,
62{
63 fn new(_: T) -> Self;
64}
65
66impl<R> Decoder<R> for GzDecoder<R>
67where
68 R: Read,
69{
70 fn new(reader: R) -> Self {
71 GzDecoder::new(reader)
72 }
73}
74
75impl<R> Decoder<R> for DeflateDecoder<R>
76where
77 R: Read,
78{
79 fn new(reader: R) -> Self {
80 DeflateDecoder::new(reader)
81 }
82}
83
84impl<R> Decoder<R> for ZlibDecoder<R>
85where
86 R: Read,
87{
88 fn new(reader: R) -> Self {
89 ZlibDecoder::new(reader)
90 }
91}
92
93fn decompress<'a, T>(body: &'a Bytes) -> Result<String>
94where
95 T: Decoder<&'a [u8]>,
96{
97 let reader: &[u8] = body;
98 let mut decoder = T::new(reader);
99 let mut s = String::new();
100 decoder.read_to_string(&mut s).map_err(convert_err)?;
101 Ok(s)
102}
103
104pub async fn response_to_string(res: Response<Incoming>) -> Result<String> {
114 if let Some(content_type) = res.headers().get(CONTENT_TYPE) {
115 if content_type.to_str().map_err(convert_err)?.contains("text") {
116 if let Some(encoding) = Encoding::get_response_encoding(&res) {
117 let res = res.map(|b| b.boxed());
118 let body: Bytes = res.collect().await.map_err(convert_err)?.to_bytes();
119
120 match encoding {
121 Encoding::Gzip => decompress::<GzDecoder<_>>(&body),
122 Encoding::Deflate => decompress::<DeflateDecoder<_>>(&body),
123 Encoding::Identity => {
124 let body = String::from_utf8_lossy(&body).to_string();
125 Ok(body)
126 }
127 }
128 } else {
129 Err(format!("Unknown Content-Type").into())
130 }
131 } else {
132 Err("Content-Type does not specify text".to_string().into())
133 }
134 } else {
135 Err("No Content-Type specified".to_string().into())
136 }
137}
138
/// Content codings this module can negotiate, produce and decode.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Encoding {
    /// The `gzip` coding.
    Gzip,
    /// The `deflate` coding.
    Deflate,
    /// No transformation of the body (`identity`).
    Identity,
}
151
152impl Encoding {
153 fn as_header_value(&self) -> HeaderValue {
154 match &self {
155 Encoding::Gzip => HeaderValue::from_static("gzip"),
156
157 Encoding::Deflate => HeaderValue::from_static("deflate"),
158
159 Encoding::Identity => HeaderValue::from_static("identity"),
160 }
161 }
162
163 pub fn get_response_encoding(res: &Response<Incoming>) -> Option<Self> {
166 if let Some(content_encoding) = res
167 .headers()
168 .get(CONTENT_ENCODING)
169 .and_then(|v| v.to_str().ok())
170 {
171 match content_encoding {
172 "gzip" | "x-gzip" => Some(Encoding::Gzip),
173 "deflate" => Some(Encoding::Gzip),
174 "identitiy" => Some(Encoding::Identity),
175 _ => None,
176 }
177 } else {
178 Some(Encoding::Identity)
179 }
180 }
181}
182
183pub fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
185 Full::new(chunk.into())
186 .map_err(|never| match never {})
187 .boxed()
188}
189
190impl FromStr for Encoding {
191 type Err = HyperContentEncodingError;
192 fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
193 match value {
194 "x-gzip" | "gzip" => Ok(Encoding::Gzip),
195 "deflate" => Ok(Encoding::Deflate),
196 "identity" => Ok(Encoding::Identity),
197 _ => Err(format!("Unrecognized encoding {}", value).into()),
198 }
199 }
200}
201
202pub async fn encode_response(res: Res, content_encoding: Encoding) -> Result<Res> {
210 let headers = res.headers().clone();
211 let status = res.status();
212
213 let res = res.map(|b| b.boxed());
214
215 let body: Bytes = res
216 .into_body()
217 .collect()
218 .await
219 .map_err(convert_err)?
220 .to_bytes();
221
222 let mut ret_vec: Vec<u8> = Vec::new();
223 match content_encoding {
224 Encoding::Gzip => {
225 let body: &[u8] = &body;
226 GzEncoder::new(body, Compression::fast())
227 .read_to_end(&mut ret_vec)
228 .map_err(convert_err)
229 }
230
231 Encoding::Deflate => {
232 let body: &[u8] = &body;
233 DeflateEncoder::new(body, Compression::fast())
234 .read_to_end(&mut ret_vec)
235 .map_err(convert_err)
236 }
237
238 Encoding::Identity => {
239 ret_vec = body.into();
240 Ok(ret_vec.len())
241 }
242 }?;
243
244 let body: Bytes = ret_vec.into();
245 let body_len = body.len();
246
247 let mut res = Response::new(full(body));
248 *res.headers_mut() = headers;
249 *res.status_mut() = status;
250
251 res.headers_mut()
252 .insert(CONTENT_ENCODING, content_encoding.as_header_value());
253
254 res.headers_mut().insert(
255 CONTENT_LENGTH,
256 body_len
257 .to_string()
258 .parse()
259 .expect("Unexpected Content-Length"),
260 );
261
262 Ok(res)
263}
264
/// Service wrapper that compresses the wrapped service's responses according to the request's
/// `Accept-Encoding` header.
#[derive(Clone, Debug)]
pub struct Compressor<S> {
    inner: S,
}

impl<S> Compressor<S> {
    /// Wraps `inner`; requests are forwarded to it and its responses are encoded.
    pub fn new(inner: S) -> Self {
        Self { inner }
    }
}
277
/// Parses the weight (`q` parameter) that may follow a coding in an `Accept-Encoding` element.
///
/// `input` is the text right after the coding token, e.g. `" ; q=0.5"`. The following are
/// required, in order:
/// - optional whitespace and a `;`;
/// - optional whitespace and `q=`, where `q` may be either case because ABNF literals are
///   case-insensitive.
///
/// The run of digits and dots after `q=` is parsed as the weight; anything after that run is
/// ignored.
///
/// Returns `None` when no weight is present, it fails to parse, or it lies outside the closed
/// range `0.0..=1.0` (RFC 9110 §12.4.2).
fn parse_weight(input: &str) -> Option<f32> {
    let rest = input.trim_start().strip_prefix(';')?.trim_start();
    let rest = rest
        .strip_prefix(|c: char| c == 'q' || c == 'Q')?
        .strip_prefix('=')?;

    let number_end = rest
        .find(|c: char| !c.is_ascii_digit() && c != '.')
        .unwrap_or(rest.len());
    let qvalue: f32 = rest[..number_end].parse().ok()?;

    // Inclusive upper bound: `q=1` / `q=1.0` is the maximum valid weight, not an error.
    if (0.0..=1.0).contains(&qvalue) {
        Some(qvalue)
    } else {
        None
    }
}
334
335fn parse_encoding(accepted_encodings: &str) -> Vec<(Encoding, f32)> {
337 let mut accepted_encodings = accepted_encodings.trim();
338
339 let mut res = Vec::new();
340
341 let mut default_weight: Option<f32> = None;
342
343 loop {
344 for token in ["gzip", "deflate", "identity", "*"] {
345 if accepted_encodings.starts_with(token) {
346 (_, accepted_encodings) = accepted_encodings.split_at(token.len());
347
348 let mut weight: f32 = 1.0;
349 if let Some(res) = parse_weight(accepted_encodings) {
350 weight = res;
351 }
352
353 if token == "*" {
354 default_weight = Some(weight);
355 } else {
356 res.push((Encoding::from_str(token).unwrap(), weight));
357 }
358
359 break;
360 }
361 }
362
363 if let Some(index) = accepted_encodings.find(',') {
364 (_, accepted_encodings) = accepted_encodings.split_at(index);
365 let mut chars = accepted_encodings.chars();
366 chars.next();
367 accepted_encodings = chars.as_str().trim();
368 } else {
369 break;
370 }
371 }
372
373 if let Some(weigth) = default_weight {
374 for encoding in [Encoding::Gzip, Encoding::Deflate, Encoding::Identity] {
375 if !res.iter().any(|(x, _)| *x == encoding) {
376 res.push((encoding, weigth));
377 }
378 }
379 } else if !res.iter().any(|(x, _)| *x == Encoding::Identity) {
380 res.push((Encoding::Identity, 1.0));
381 }
382
383 res
384
385 }
387
388fn prefered_encoding(accepted_encodings: &str) -> Option<Encoding> {
389 let mut encodings = parse_encoding(accepted_encodings);
390 encodings.sort_by_key(|&(_, w)| -(w * 1000.0) as i32);
391
392 encodings
393 .iter()
394 .find(|(_, w)| *w > 0.0)
395 .map(|(e, _)| e.to_owned())
396}
397
398type Req = Request<Incoming>;
399type Res = Response<BoxBody<Bytes, hyper::Error>>;
400
401impl<S> Service<Req> for Compressor<S>
402where
403 S: Service<Req, Response = Res>,
404 S::Future: 'static + Send,
405 S::Error: 'static + Send,
406 S::Error: Debug,
407 S::Response: 'static,
408{
409 type Response = Res;
410 type Error = Box<HyperContentEncodingError>;
411 type Future =
412 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
413
414 fn call(&self, req: Req) -> Self::Future {
415 let headers = req.headers().clone();
416
417 let encoding = if let Some(accepted_encodings) = headers.get(ACCEPT_ENCODING) {
419 if let Some(desired_encoding) =
420 accepted_encodings.to_str().ok().and_then(prefered_encoding)
421 {
422 desired_encoding
423 } else {
424 return Box::pin(async move {
425 let mut res = Response::new(full("Unsuported requestedd encoding\n"));
426 *res.status_mut() = StatusCode::NOT_ACCEPTABLE;
427 Ok(res)
428 });
429 }
430 } else {
431 Encoding::Gzip
432 };
433
434 let fut = self.inner.call(req);
435
436 let f = async move {
437 match fut.await {
438 Ok(response) => encode_response(response, encoding).await.map_err(Box::new),
439 Err(e) => Err(Box::new(convert_err(e))),
440 }
441 };
442
443 Box::pin(f)
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Unsupported codings are skipped and identity is implied.
    #[test]
    fn no_weight() {
        let result = parse_encoding("compress, gzip");
        assert_eq!(
            vec![(Encoding::Gzip, 1.0), (Encoding::Identity, 1.0)],
            result
        );
    }

    /// An empty header still allows identity.
    #[test]
    fn empty() {
        let result = parse_encoding("");
        assert_eq!(vec![(Encoding::Identity, 1.0)], result);
    }

    /// `*` expands to every supported coding.
    #[test]
    fn star() {
        let result = parse_encoding("*");
        assert_eq!(
            vec![
                (Encoding::Gzip, 1.0),
                (Encoding::Deflate, 1.0),
                (Encoding::Identity, 1.0)
            ],
            result
        );
    }

    /// Explicit weights are kept in header order.
    #[test]
    fn weight() {
        let result = parse_encoding("deflate;q=0.5, gzip;q=1.0");
        assert_eq!(
            vec![
                (Encoding::Deflate, 0.5),
                (Encoding::Gzip, 1.0),
                (Encoding::Identity, 1.0)
            ],
            result
        );
    }

    /// `*;q=0` excludes identity when it is not listed.
    #[test]
    fn no_identity() {
        let result = parse_encoding("gzip;q=1.0, deflate; q=0.5, *;q=0");
        assert_eq!(
            vec![
                (Encoding::Gzip, 1.0),
                (Encoding::Deflate, 0.5),
                (Encoding::Identity, 0.0)
            ],
            result
        );
    }

    /// A refused coding is skipped; ties resolve to the earlier-listed coding.
    #[test]
    fn prefers_first_acceptable() {
        assert_eq!(
            Some(Encoding::Deflate),
            prefered_encoding("gzip;q=0, deflate")
        );
    }

    /// When everything is refused there is no acceptable coding.
    #[test]
    fn nothing_acceptable() {
        assert_eq!(None, prefered_encoding("identity;q=0, *;q=0"));
    }
}