Skip to main content

avalanche_types/message/
accepted.rs

1use std::io::{self, Error, ErrorKind};
2
3use crate::{ids, message, proto::pb::p2p};
4use prost::Message as ProstMessage;
5
6#[derive(Debug, PartialEq, Clone)]
7pub struct Message {
8    pub msg: p2p::Accepted,
9    pub gzip_compress: bool,
10}
11
12impl Default for Message {
13    fn default() -> Self {
14        Message {
15            msg: p2p::Accepted {
16                chain_id: prost::bytes::Bytes::new(),
17                request_id: 0,
18                container_ids: Vec::new(),
19            },
20            gzip_compress: false,
21        }
22    }
23}
24
25impl Message {
26    #[must_use]
27    pub fn chain_id(mut self, chain_id: ids::Id) -> Self {
28        self.msg.chain_id = prost::bytes::Bytes::from(chain_id.to_vec());
29        self
30    }
31
32    #[must_use]
33    pub fn request_id(mut self, request_id: u32) -> Self {
34        self.msg.request_id = request_id;
35        self
36    }
37
38    #[must_use]
39    pub fn container_ids(mut self, container_ids: Vec<ids::Id>) -> Self {
40        let mut container_ids_bytes: Vec<prost::bytes::Bytes> =
41            Vec::with_capacity(container_ids.len());
42        for id in container_ids.iter() {
43            container_ids_bytes.push(prost::bytes::Bytes::from(id.to_vec()));
44        }
45        self.msg.container_ids = container_ids_bytes;
46        self
47    }
48
49    #[must_use]
50    pub fn gzip_compress(mut self, gzip_compress: bool) -> Self {
51        self.gzip_compress = gzip_compress;
52        self
53    }
54
55    pub fn serialize(&self) -> io::Result<Vec<u8>> {
56        let msg = p2p::Message {
57            message: Some(p2p::message::Message::Accepted(self.msg.clone())),
58        };
59        let encoded = ProstMessage::encode_to_vec(&msg);
60        if !self.gzip_compress {
61            return Ok(encoded);
62        }
63
64        let uncompressed_len = encoded.len();
65        let compressed = message::compress::pack_gzip(&encoded)?;
66        let msg = p2p::Message {
67            message: Some(p2p::message::Message::CompressedGzip(
68                prost::bytes::Bytes::from(compressed),
69            )),
70        };
71
72        let compressed_len = msg.encoded_len();
73        if uncompressed_len > compressed_len {
74            log::debug!(
75                "accepted compression saved {} bytes",
76                uncompressed_len - compressed_len
77            );
78        } else {
79            log::debug!(
80                "accepted compression added {} byte(s)",
81                compressed_len - uncompressed_len
82            );
83        }
84
85        Ok(ProstMessage::encode_to_vec(&msg))
86    }
87
88    pub fn deserialize(d: impl AsRef<[u8]>) -> io::Result<Self> {
89        let buf = bytes::Bytes::from(d.as_ref().to_vec());
90        let p2p_msg: p2p::Message = ProstMessage::decode(buf).map_err(|e| {
91            Error::new(
92                ErrorKind::InvalidData,
93                format!("failed prost::Message::decode '{}'", e),
94            )
95        })?;
96
97        match p2p_msg.message.unwrap() {
98            // was not compressed
99            p2p::message::Message::Accepted(msg) => Ok(Message {
100                msg,
101                gzip_compress: false,
102            }),
103
104            // was compressed, so need decompress first
105            p2p::message::Message::CompressedGzip(msg) => {
106                let decompressed = message::compress::unpack_gzip(msg.as_ref())?;
107                let decompressed_msg: p2p::Message =
108                    ProstMessage::decode(prost::bytes::Bytes::from(decompressed)).map_err(|e| {
109                        Error::new(
110                            ErrorKind::InvalidData,
111                            format!("failed prost::Message::decode '{}'", e),
112                        )
113                    })?;
114                match decompressed_msg.message.unwrap() {
115                    p2p::message::Message::Accepted(msg) => Ok(Message {
116                        msg,
117                        gzip_compress: false,
118                    }),
119                    _ => Err(Error::new(
120                        ErrorKind::InvalidInput,
121                        "unknown message type after decompress",
122                    )),
123                }
124            }
125
126            // unknown message enum
127            _ => Err(Error::new(ErrorKind::InvalidInput, "unknown message type")),
128        }
129    }
130}
131
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- message::accepted::test_message --exact --show-output
#[test]
fn test_message() {
    // Ignore the result: the logger may already be initialized.
    let _ = env_logger::builder()
        .filter_level(log::LevelFilter::Debug)
        .is_test(true)
        .try_init();

    let chain_id = ids::Id::from_slice(&random_manager::secure_bytes(32).unwrap());
    let request_id = random_manager::u32();

    // Ten empty IDs followed by two random ones.
    let mut container_ids: Vec<ids::Id> = std::iter::repeat_with(ids::Id::empty)
        .take(10)
        .collect();
    container_ids.extend(
        (0..2).map(|_| ids::Id::from_slice(&random_manager::secure_bytes(32).unwrap())),
    );

    let plain = Message::default()
        .chain_id(chain_id)
        .request_id(request_id)
        .container_ids(container_ids);

    // Round trip without compression must reproduce the message exactly.
    let plain_bytes = plain.serialize().unwrap();
    let plain_roundtrip = Message::deserialize(plain_bytes).unwrap();
    assert_eq!(plain, plain_roundtrip);

    // Toggling the compression flag alone makes the two messages unequal.
    let gzipped = plain.clone().gzip_compress(true);
    assert_ne!(plain, gzipped);

    // Deserializing compressed bytes yields a message with the flag cleared,
    // so it compares equal to the uncompressed original.
    let gzipped_bytes = gzipped.serialize().unwrap();
    let gzipped_roundtrip = Message::deserialize(gzipped_bytes).unwrap();
    assert_eq!(plain, gzipped_roundtrip);
}
169
170    let data2 = msg2_with_compression.serialize().unwrap();
171    let msg2_with_compression_deserialized = Message::deserialize(data2).unwrap();
172    assert_eq!(msg1_with_no_compression, msg2_with_compression_deserialized);
173}