avalanche_types/message/
ancestors.rs

1use std::io::{self, Error, ErrorKind};
2
3use crate::{ids, message, proto::pb::p2p};
4use prost::Message as ProstMessage;
5
6#[derive(Debug, PartialEq, Clone)]
7pub struct Message {
8    pub msg: p2p::Ancestors,
9    pub gzip_compress: bool,
10}
11
12impl Default for Message {
13    fn default() -> Self {
14        Message {
15            msg: p2p::Ancestors {
16                chain_id: prost::bytes::Bytes::new(),
17                request_id: 0,
18                containers: Vec::new(),
19            },
20            gzip_compress: false,
21        }
22    }
23}
24
25impl Message {
26    #[must_use]
27    pub fn chain_id(mut self, chain_id: ids::Id) -> Self {
28        self.msg.chain_id = prost::bytes::Bytes::from(chain_id.to_vec());
29        self
30    }
31
32    #[must_use]
33    pub fn request_id(mut self, request_id: u32) -> Self {
34        self.msg.request_id = request_id;
35        self
36    }
37
38    #[must_use]
39    pub fn containers(mut self, containers: Vec<Vec<u8>>) -> Self {
40        let mut containers_bytes: Vec<prost::bytes::Bytes> = Vec::with_capacity(containers.len());
41        for b in containers {
42            containers_bytes.push(prost::bytes::Bytes::from(b));
43        }
44        self.msg.containers = containers_bytes;
45        self
46    }
47
48    #[must_use]
49    pub fn gzip_compress(mut self, gzip_compress: bool) -> Self {
50        self.gzip_compress = gzip_compress;
51        self
52    }
53
54    pub fn serialize(&self) -> io::Result<Vec<u8>> {
55        let msg = p2p::Message {
56            message: Some(p2p::message::Message::Ancestors(self.msg.clone())),
57        };
58        let encoded = ProstMessage::encode_to_vec(&msg);
59        if !self.gzip_compress {
60            return Ok(encoded);
61        }
62
63        let uncompressed_len = encoded.len();
64        let compressed = message::compress::pack_gzip(&encoded)?;
65        let msg = p2p::Message {
66            message: Some(p2p::message::Message::CompressedGzip(
67                prost::bytes::Bytes::from(compressed),
68            )),
69        };
70
71        let compressed_len = msg.encoded_len();
72        if uncompressed_len > compressed_len {
73            log::debug!(
74                "ancestors compression saved {} bytes",
75                uncompressed_len - compressed_len
76            );
77        } else {
78            log::debug!(
79                "ancestors compression added {} byte(s)",
80                compressed_len - uncompressed_len
81            );
82        }
83
84        Ok(ProstMessage::encode_to_vec(&msg))
85    }
86
87    pub fn deserialize(d: impl AsRef<[u8]>) -> io::Result<Self> {
88        let buf = bytes::Bytes::from(d.as_ref().to_vec());
89        let p2p_msg: p2p::Message = ProstMessage::decode(buf).map_err(|e| {
90            Error::new(
91                ErrorKind::InvalidData,
92                format!("failed prost::Message::decode '{}'", e),
93            )
94        })?;
95
96        match p2p_msg.message.unwrap() {
97            // was not compressed
98            p2p::message::Message::Ancestors(msg) => Ok(Message {
99                msg,
100                gzip_compress: false,
101            }),
102
103            // was compressed, so need decompress first
104            p2p::message::Message::CompressedGzip(msg) => {
105                let decompressed = message::compress::unpack_gzip(msg.as_ref())?;
106                let decompressed_msg: p2p::Message =
107                    ProstMessage::decode(prost::bytes::Bytes::from(decompressed)).map_err(|e| {
108                        Error::new(
109                            ErrorKind::InvalidData,
110                            format!("failed prost::Message::decode '{}'", e),
111                        )
112                    })?;
113                match decompressed_msg.message.unwrap() {
114                    p2p::message::Message::Ancestors(msg) => Ok(Message {
115                        msg,
116                        gzip_compress: false,
117                    }),
118                    _ => Err(Error::new(
119                        ErrorKind::InvalidInput,
120                        "unknown message type after decompress",
121                    )),
122                }
123            }
124
125            // unknown message enum
126            _ => Err(Error::new(ErrorKind::InvalidInput, "unknown message type")),
127        }
128    }
129}
130
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- message::ancestors::test_message --exact --show-output
#[test]
fn test_message() {
    let _ = env_logger::builder()
        .filter_level(log::LevelFilter::Debug)
        .is_test(true)
        .try_init();

    // Random inputs: a 32-byte chain ID, a request ID, and seven 30-byte containers.
    let chain_id_bytes = random_manager::secure_bytes(32).unwrap();
    let request_id = random_manager::u32();
    let containers: Vec<Vec<u8>> = (0..7)
        .map(|_| random_manager::secure_bytes(30).unwrap())
        .collect();

    let plain = Message::default()
        .chain_id(ids::Id::from_slice(&chain_id_bytes))
        .request_id(request_id)
        .containers(containers);

    // Round trip without compression must reproduce the original message.
    let plain_bytes = plain.serialize().unwrap();
    let plain_round_trip = Message::deserialize(plain_bytes).unwrap();
    assert_eq!(plain, plain_round_trip);

    // Toggling the compression flag makes the two values compare unequal.
    let compressed = plain.clone().gzip_compress(true);
    assert_ne!(plain, compressed);

    // Deserializing a compressed message yields the uncompressed form,
    // because `deserialize` always resets `gzip_compress` to `false`.
    let compressed_bytes = compressed.serialize().unwrap();
    let compressed_round_trip = Message::deserialize(compressed_bytes).unwrap();
    assert_eq!(plain, compressed_round_trip);
}
167}