dubbo/triple/
compression.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17
18use std::collections::HashMap;
19
20use bytes::{Buf, BufMut, BytesMut};
21use flate2::{
22    read::{GzDecoder, GzEncoder},
23    Compression,
24};
25use lazy_static::lazy_static;
26
/// Name of the header in which a peer lists the compression encodings it accepts
/// (read by [`CompressionEncoding::from_accept_encoding`]).
pub const GRPC_ACCEPT_ENCODING: &str = "grpc-accept-encoding";
/// Name of the header that names the compression encoding applied to a message.
pub const GRPC_ENCODING: &str = "grpc-encoding";

/// Compression algorithms this module can apply to message payloads.
///
/// `PartialEq`/`Eq`/`Hash` are derived so encodings can be compared directly and
/// used as keys in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionEncoding {
    /// gzip; its header token is `"gzip"`.
    Gzip,
}
34
35lazy_static! {
36    pub static ref COMPRESSIONS: HashMap<String, Option<CompressionEncoding>> = {
37        let mut v = HashMap::new();
38        v.insert("gzip".to_string(), Some(CompressionEncoding::Gzip));
39        v
40    };
41}
42
43impl CompressionEncoding {
44    pub fn from_accept_encoding(header: &http::HeaderMap) -> Option<CompressionEncoding> {
45        let accept_encoding = header.get(GRPC_ACCEPT_ENCODING)?;
46        let encodings = accept_encoding.to_str().ok()?;
47
48        encodings
49            .trim()
50            .split(',')
51            .map(|s| s.trim())
52            .into_iter()
53            .find_map(|s| match s {
54                "gzip" => Some(CompressionEncoding::Gzip),
55                _ => None,
56            })
57    }
58
59    pub fn into_header_value(self) -> http::HeaderValue {
60        match self {
61            CompressionEncoding::Gzip => http::HeaderValue::from_static("gzip"),
62        }
63    }
64}
65
66pub fn compress(
67    encoding: CompressionEncoding,
68    src: &mut BytesMut,
69    dst: &mut BytesMut,
70    len: usize,
71) -> Result<(), std::io::Error> {
72    dst.reserve(len);
73
74    match encoding {
75        CompressionEncoding::Gzip => {
76            let mut en = GzEncoder::new(src.reader(), Compression::default());
77
78            let mut dst_writer = dst.writer();
79
80            std::io::copy(&mut en, &mut dst_writer)?;
81        }
82    }
83
84    Ok(())
85}
86
87pub fn decompress(
88    encoding: CompressionEncoding,
89    src: &mut BytesMut,
90    dst: &mut BytesMut,
91    len: usize,
92) -> Result<(), std::io::Error> {
93    let capacity = len * 2;
94    dst.reserve(capacity);
95
96    match encoding {
97        CompressionEncoding::Gzip => {
98            let mut de = GzDecoder::new(src.reader());
99
100            let mut dst_writer = dst.writer();
101
102            std::io::copy(&mut de, &mut dst_writer)?;
103        }
104    }
105    Ok(())
106}
107
/// Round-trips a small payload through gzip `compress` and `decompress`.
///
/// Checks that the output matches the input and that each call fully consumes its
/// source buffer.
#[test]
fn test_compress() {
    let payload = &b"test compress"[..];

    let mut src = BytesMut::with_capacity(super::consts::BUFFER_SIZE);
    src.put(payload);
    let mut dst = BytesMut::new();
    let len = src.len();

    compress(CompressionEncoding::Gzip, &mut src, &mut dst, len).unwrap();
    assert!(src.is_empty(), "compress should consume the whole input frame");
    assert!(!dst.is_empty(), "compress should produce output");

    let mut de_dst = BytesMut::with_capacity(super::consts::BUFFER_SIZE);
    let de_len = dst.len();
    decompress(CompressionEncoding::Gzip, &mut dst, &mut de_dst, de_len).unwrap();
    assert!(dst.is_empty(), "decompress should consume the whole compressed frame");

    assert_eq!(&de_dst[..], payload);
}