Skip to main content

avalanche_types/message/
chits.rs

1use std::io::{self, Error, ErrorKind};
2
3use crate::{ids, message, proto::pb::p2p};
4use prost::bytes::Bytes;
5use prost::Message as ProstMessage;
6
7#[derive(Debug, PartialEq, Clone)]
8pub struct Message {
9    pub msg: p2p::Chits,
10    pub gzip_compress: bool,
11}
12
13impl Default for Message {
14    fn default() -> Self {
15        Message {
16            msg: p2p::Chits {
17                chain_id: Bytes::new(),
18                request_id: 0,
19                preferred_id: Bytes::new(),
20                accepted_id: Bytes::new(),
21                preferred_id_at_height: Bytes::new(),
22            },
23            gzip_compress: false,
24        }
25    }
26}
27
28impl Message {
29    #[must_use]
30    pub fn chain_id(mut self, chain_id: ids::Id) -> Self {
31        self.msg.chain_id = Bytes::from(chain_id.to_vec());
32        self
33    }
34
35    #[must_use]
36    pub fn request_id(mut self, request_id: u32) -> Self {
37        self.msg.request_id = request_id;
38        self
39    }
40
41    #[must_use]
42    pub fn container_id(mut self, id: ids::Id) -> Self {
43        self.msg.preferred_id = Bytes::from(id.to_vec());
44        self
45    }
46
47    #[must_use]
48    pub fn gzip_compress(mut self, gzip_compress: bool) -> Self {
49        self.gzip_compress = gzip_compress;
50        self
51    }
52
53    pub fn serialize(&self) -> io::Result<Vec<u8>> {
54        let msg = p2p::Message {
55            message: Some(p2p::message::Message::Chits(self.msg.clone())),
56        };
57        let encoded = ProstMessage::encode_to_vec(&msg);
58        if !self.gzip_compress {
59            return Ok(encoded);
60        }
61
62        let uncompressed_len = encoded.len();
63        let compressed = message::compress::pack_gzip(&encoded)?;
64        let msg = p2p::Message {
65            message: Some(p2p::message::Message::CompressedGzip(Bytes::from(
66                compressed,
67            ))),
68        };
69
70        let compressed_len = msg.encoded_len();
71        if uncompressed_len > compressed_len {
72            log::debug!(
73                "chits compression saved {} bytes",
74                uncompressed_len - compressed_len
75            );
76        } else {
77            log::debug!(
78                "chits compression added {} byte(s)",
79                compressed_len - uncompressed_len
80            );
81        }
82
83        Ok(ProstMessage::encode_to_vec(&msg))
84    }
85
86    pub fn deserialize(d: impl AsRef<[u8]>) -> io::Result<Self> {
87        let buf = bytes::Bytes::from(d.as_ref().to_vec());
88        let p2p_msg: p2p::Message = ProstMessage::decode(buf).map_err(|e| {
89            Error::new(
90                ErrorKind::InvalidData,
91                format!("failed prost::Message::decode '{}'", e),
92            )
93        })?;
94
95        match p2p_msg.message.unwrap() {
96            // was not compressed
97            p2p::message::Message::Chits(msg) => Ok(Message {
98                msg,
99                gzip_compress: false,
100            }),
101
102            // was compressed, so need decompress first
103            p2p::message::Message::CompressedGzip(msg) => {
104                let decompressed = message::compress::unpack_gzip(msg.as_ref())?;
105                let decompressed_msg: p2p::Message =
106                    ProstMessage::decode(Bytes::from(decompressed)).map_err(|e| {
107                        Error::new(
108                            ErrorKind::InvalidData,
109                            format!("failed prost::Message::decode '{}'", e),
110                        )
111                    })?;
112                match decompressed_msg.message.unwrap() {
113                    p2p::message::Message::Chits(msg) => Ok(Message {
114                        msg,
115                        gzip_compress: false,
116                    }),
117                    _ => Err(Error::new(
118                        ErrorKind::InvalidInput,
119                        "unknown message type after decompress",
120                    )),
121                }
122            }
123
124            // unknown message enum
125            _ => Err(Error::new(ErrorKind::InvalidInput, "unknown message type")),
126        }
127    }
128}
129
/// Round-trips a random `Chits` message through `serialize`/`deserialize`,
/// first uncompressed and then gzip-compressed.
///
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- message::chits::test_message --exact --show-output
#[test]
fn test_message() {
    let _ = env_logger::builder()
        .filter_level(log::LevelFilter::Debug)
        .is_test(true)
        .try_init();

    let random_id = || ids::Id::from_slice(&random_manager::secure_bytes(32).unwrap());

    let plain = Message::default()
        .chain_id(random_id())
        .request_id(random_manager::u32())
        .container_id(random_id());

    let plain_bytes = plain.serialize().unwrap();
    let plain_roundtrip = Message::deserialize(plain_bytes).unwrap();
    assert_eq!(plain, plain_roundtrip);

    let compressed = plain.clone().gzip_compress(true);
    assert_ne!(plain, compressed);

    let compressed_bytes = compressed.serialize().unwrap();
    let compressed_roundtrip = Message::deserialize(compressed_bytes).unwrap();
    // `deserialize` reports `gzip_compress == false` even for compressed
    // input, so the round trip must equal the uncompressed original.
    assert_eq!(plain, compressed_roundtrip);
}