ckb_network/
compress.rs

//! CKB network compression module.
2
3use ckb_logger::debug;
4use p2p::bytes::{BufMut, Bytes, BytesMut};
5use snap::raw::{Decoder as SnapDecoder, Encoder as SnapEncoder, decompress_len};
6use tokio_util::codec::length_delimited;
7
8use std::io;
9
/// Payloads longer than this many bytes are candidates for snappy compression.
pub(crate) const COMPRESSION_SIZE_THRESHOLD: usize = 1024;
/// Flag byte value marking an uncompressed payload.
pub(crate) const UNCOMPRESS_FLAG: u8 = 0x00;
/// Flag byte value (most significant bit set) marking a snappy-compressed payload.
const COMPRESS_FLAG: u8 = 0x80;
/// Upper bound on the decompressed size accepted from a snappy header (8 MiB).
const MAX_UNCOMPRESSED_LEN: usize = 8 * 1024 * 1024;
14
15/// Compressed decompression structure
16///
17/// If you want to support multiple compression formats in the future,
18/// you can simply think that 0b1000 is in snappy format and 0b0000 is in uncompressed format.
19///
20/// # Message in Bytes:
21///
22/// +---------------------------------------------------------------+
23/// | Bytes | Type | Function                                       |
24/// |-------+------+------------------------------------------------|
25/// |   0   |  u1  | Compress: true 1, false 0                      |
26/// |       |  u7  | Reserved                                       |
27/// +-------+------+------------------------------------------------+
28/// |  1~   |      | Payload (Serialized Data with Compress)        |
29/// +-------+------+------------------------------------------------+
30#[derive(Clone, Debug)]
31pub(crate) struct Message {
32    inner: BytesMut,
33}
34
35impl Message {
36    /// create from uncompressed raw data
37    pub(crate) fn from_raw(data: Bytes) -> Self {
38        let mut inner = BytesMut::with_capacity(data.len() + 1);
39        inner.put_u8(UNCOMPRESS_FLAG);
40        inner.put(data);
41        Self { inner }
42    }
43
44    /// create from compressed data
45    pub(crate) fn from_compressed(data: BytesMut) -> Self {
46        Self { inner: data }
47    }
48
49    /// Compress message
50    pub(crate) fn compress(mut self) -> Bytes {
51        if self.inner.len() > COMPRESSION_SIZE_THRESHOLD {
52            let input = self.inner.split_off(1);
53            match SnapEncoder::new().compress_vec(&input) {
54                Ok(res) => {
55                    self.inner.extend_from_slice(&res);
56                    self.set_compress_flag();
57                }
58                Err(e) => {
59                    debug!("snappy compress error: {}", e);
60                    self.inner.unsplit(input);
61                }
62            }
63        }
64        self.inner.freeze()
65    }
66
67    /// Decompress message
68    pub(crate) fn decompress(mut self) -> Result<Bytes, io::Error> {
69        if self.inner.is_empty() {
70            Err(io::ErrorKind::InvalidData.into())
71        } else if self.compress_flag() {
72            match decompress_len(&self.inner[1..]) {
73                Ok(decompressed_bytes_len) => {
74                    if decompressed_bytes_len > MAX_UNCOMPRESSED_LEN {
75                        debug!(
76                            "The limit for uncompressed bytes len is exceeded. limit: {}, len: {}",
77                            MAX_UNCOMPRESSED_LEN, decompressed_bytes_len
78                        );
79                        Err(io::ErrorKind::InvalidData.into())
80                    } else {
81                        let mut buf = vec![0; decompressed_bytes_len];
82                        match SnapDecoder::new().decompress(&self.inner[1..], &mut buf) {
83                            Ok(_) => Ok(buf.into()),
84                            Err(e) => {
85                                debug!("snappy decompress error: {:?}", e);
86                                Err(io::ErrorKind::InvalidData.into())
87                            }
88                        }
89                    }
90                }
91                Err(e) => {
92                    debug!("snappy decompress_len error: {:?}", e);
93                    Err(io::ErrorKind::InvalidData.into())
94                }
95            }
96        } else {
97            let _ = self.inner.split_to(1);
98            Ok(self.inner.freeze())
99        }
100    }
101
102    pub(crate) fn set_compress_flag(&mut self) {
103        self.inner[0] = COMPRESS_FLAG;
104    }
105
106    pub(crate) fn compress_flag(&self) -> bool {
107        (self.inner[0] & COMPRESS_FLAG) != 0
108    }
109}
110
111/// Compress data
112pub fn compress(src: Bytes) -> Bytes {
113    Message::from_raw(src).compress()
114}
115
116/// Decompress data
117pub fn decompress(src: BytesMut) -> Result<Bytes, io::Error> {
118    Message::from_compressed(src).decompress()
119}
120
121/// LengthDelimitedCodec with compress support
122pub struct LengthDelimitedCodecWithCompress {
123    length_delimited: length_delimited::LengthDelimitedCodec,
124    enable_compress: bool,
125    protocol_id: p2p::ProtocolId,
126}
127
128impl LengthDelimitedCodecWithCompress {
129    /// Create a new LengthDelimitedCodecWithCompress
130    pub fn new(
131        enable_compress: bool,
132        length_delimited: length_delimited::LengthDelimitedCodec,
133        protocol_id: p2p::ProtocolId,
134    ) -> Self {
135        Self {
136            length_delimited,
137            enable_compress,
138            protocol_id,
139        }
140    }
141
142    fn process(&self, data: &[u8], flag: u8, dst: &mut BytesMut) -> Result<(), io::Error> {
143        let len = data.len() + 1;
144        if len > self.length_delimited.max_frame_length() {
145            return Err(io::Error::new(
146                io::ErrorKind::InvalidInput,
147                "data too large",
148            ));
149        }
150        dst.reserve(4 + len);
151        dst.put_uint(len as u64, 4);
152        dst.put_u8(flag);
153        dst.extend_from_slice(data);
154        Ok(())
155    }
156}
157
158impl tokio_util::codec::Encoder<Bytes> for LengthDelimitedCodecWithCompress {
159    type Error = io::Error;
160    fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
161        if self.enable_compress && data.len() > COMPRESSION_SIZE_THRESHOLD {
162            match SnapEncoder::new().compress_vec(&data) {
163                Ok(res) => {
164                    debug!(
165                        "protocol {} message snappy compress result: raw: {}, compressed: {}, ratio: {:.2}%",
166                        self.protocol_id,
167                        data.len(),
168                        res.len(),
169                        (res.len() as f64 / data.len() as f64 * 100.0)
170                    );
171                    if let Some(metrics) = ckb_metrics::handle() {
172                        metrics
173                            .ckb_network_compress
174                            .with_label_values(&[
175                                self.protocol_id.to_string().as_str(),
176                                "succeeded",
177                                "compressed ratio",
178                            ])
179                            .observe(res.len() as f64 / data.len() as f64);
180                    }
181                    // compressed data is larger than or equal to uncompressed data
182                    if res.len() >= data.len() {
183                        self.process(&data, UNCOMPRESS_FLAG, dst)?;
184                    } else {
185                        self.process(&res, COMPRESS_FLAG, dst)?;
186                    }
187                }
188                Err(e) => {
189                    debug!(
190                        "protocol {} message snappy compress error: {}",
191                        self.protocol_id, e
192                    );
193                    if let Some(metrics) = ckb_metrics::handle() {
194                        metrics
195                            .ckb_network_compress
196                            .with_label_values(&[
197                                self.protocol_id.to_string().as_str(),
198                                "failed",
199                                "compressed ratio",
200                            ])
201                            .observe(1.0);
202                    }
203                    self.process(&data, UNCOMPRESS_FLAG, dst)?;
204                }
205            }
206        } else {
207            if let Some(metrics) = ckb_metrics::handle() {
208                metrics
209                    .ckb_network_not_compress_count
210                    .with_label_values(&[self.protocol_id.to_string().as_str()])
211                    .inc();
212            }
213            self.process(&data, UNCOMPRESS_FLAG, dst)?;
214        }
215        Ok(())
216    }
217}
218
219impl tokio_util::codec::Decoder for LengthDelimitedCodecWithCompress {
220    type Item = BytesMut;
221    type Error = io::Error;
222    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
223        if src.is_empty() {
224            return Ok(None);
225        }
226        match self.length_delimited.decode(src)? {
227            Some(mut data) => {
228                if data.len() < 2 {
229                    return Err(io::ErrorKind::InvalidData.into());
230                }
231
232                if (data[0] & COMPRESS_FLAG) != 0 {
233                    match decompress_len(&data[1..]) {
234                        Ok(decompressed_bytes_len) => {
235                            if decompressed_bytes_len > MAX_UNCOMPRESSED_LEN {
236                                debug!(
237                                    "The limit for uncompressed bytes len is exceeded. limit: {}, len: {}",
238                                    MAX_UNCOMPRESSED_LEN, decompressed_bytes_len
239                                );
240                                return Err(io::ErrorKind::InvalidData.into());
241                            }
242                            let mut buf = BytesMut::zeroed(decompressed_bytes_len);
243                            match SnapDecoder::new().decompress(&data[1..], &mut buf) {
244                                Ok(_) => Ok(Some(buf)),
245                                Err(e) => {
246                                    debug!("snappy decompress error: {:?}", e);
247                                    Err(io::ErrorKind::InvalidData.into())
248                                }
249                            }
250                        }
251                        Err(e) => {
252                            debug!("snappy decompress_len error: {:?}", e);
253                            Err(io::ErrorKind::InvalidData.into())
254                        }
255                    }
256                } else {
257                    Ok(Some(data.split_off(1)))
258                }
259            }
260            None => Ok(None),
261        }
262    }
263}