Skip to main content

avalanche_types/message/
accepted_frontier.rs

1use std::io::{self, Error, ErrorKind};
2
3use crate::{ids, message, proto::pb::p2p};
4use prost::bytes::Bytes;
5use prost::Message as ProstMessage;
6
7#[derive(Debug, PartialEq, Clone)]
8pub struct Message {
9    pub msg: p2p::AcceptedFrontier,
10    pub gzip_compress: bool,
11}
12
13impl Default for Message {
14    fn default() -> Self {
15        Message {
16            msg: p2p::AcceptedFrontier {
17                chain_id: Bytes::new(),
18                request_id: 0,
19                container_id: Bytes::new(),
20            },
21            gzip_compress: false,
22        }
23    }
24}
25
26impl Message {
27    #[must_use]
28    pub fn chain_id(mut self, chain_id: ids::Id) -> Self {
29        self.msg.chain_id = Bytes::from(chain_id.to_vec());
30        self
31    }
32
33    #[must_use]
34    pub fn request_id(mut self, request_id: u32) -> Self {
35        self.msg.request_id = request_id;
36        self
37    }
38
39    #[must_use]
40    pub fn container_id(mut self, id: ids::Id) -> Self {
41        self.msg.container_id = Bytes::from(id.to_vec());
42        self
43    }
44
45    #[must_use]
46    pub fn gzip_compress(mut self, gzip_compress: bool) -> Self {
47        self.gzip_compress = gzip_compress;
48        self
49    }
50
51    pub fn serialize(&self) -> io::Result<Vec<u8>> {
52        let msg = p2p::Message {
53            message: Some(p2p::message::Message::AcceptedFrontier(self.msg.clone())),
54        };
55        let encoded = ProstMessage::encode_to_vec(&msg);
56        if !self.gzip_compress {
57            return Ok(encoded);
58        }
59
60        let uncompressed_len = encoded.len();
61        let compressed = message::compress::pack_gzip(&encoded)?;
62        let msg = p2p::Message {
63            message: Some(p2p::message::Message::CompressedGzip(Bytes::from(
64                compressed,
65            ))),
66        };
67
68        let compressed_len = msg.encoded_len();
69        if uncompressed_len > compressed_len {
70            log::debug!(
71                "accepted_frontier compression saved {} bytes",
72                uncompressed_len - compressed_len
73            );
74        } else {
75            log::debug!(
76                "accepted_frontier compression added {} byte(s)",
77                compressed_len - uncompressed_len
78            );
79        }
80
81        Ok(ProstMessage::encode_to_vec(&msg))
82    }
83
84    pub fn deserialize(d: impl AsRef<[u8]>) -> io::Result<Self> {
85        let buf = bytes::Bytes::from(d.as_ref().to_vec());
86        let p2p_msg: p2p::Message = ProstMessage::decode(buf).map_err(|e| {
87            Error::new(
88                ErrorKind::InvalidData,
89                format!("failed prost::Message::decode '{}'", e),
90            )
91        })?;
92
93        match p2p_msg.message.unwrap() {
94            // was not compressed
95            p2p::message::Message::AcceptedFrontier(msg) => Ok(Message {
96                msg,
97                gzip_compress: false,
98            }),
99
100            // was compressed, so need decompress first
101            p2p::message::Message::CompressedGzip(msg) => {
102                let decompressed = message::compress::unpack_gzip(msg.as_ref())?;
103                let decompressed_msg: p2p::Message =
104                    ProstMessage::decode(Bytes::from(decompressed)).map_err(|e| {
105                        Error::new(
106                            ErrorKind::InvalidData,
107                            format!("failed prost::Message::decode '{}'", e),
108                        )
109                    })?;
110                match decompressed_msg.message.unwrap() {
111                    p2p::message::Message::AcceptedFrontier(msg) => Ok(Message {
112                        msg,
113                        gzip_compress: false,
114                    }),
115                    _ => Err(Error::new(
116                        ErrorKind::InvalidInput,
117                        "unknown message type after decompress",
118                    )),
119                }
120            }
121
122            // unknown message enum
123            _ => Err(Error::new(ErrorKind::InvalidInput, "unknown message type")),
124        }
125    }
126}
127
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- message::accepted_frontier::test_message --exact --show-output
#[test]
fn test_message() {
    let _ = env_logger::builder()
        .filter_level(log::LevelFilter::Debug)
        .is_test(true)
        .try_init();

    // Fresh random 32-byte IDs for each call.
    let random_id = || ids::Id::from_slice(&random_manager::secure_bytes(32).unwrap());

    let plain = Message::default()
        .chain_id(random_id())
        .request_id(random_manager::u32())
        .container_id(random_id());

    // Round trip without compression must reproduce the original message.
    let plain_bytes = plain.serialize().unwrap();
    assert_eq!(plain, Message::deserialize(plain_bytes).unwrap());

    // Enabling compression changes the flag, but deserialization always yields an
    // uncompressed message equal to the original.
    let gzipped = plain.clone().gzip_compress(true);
    assert_ne!(plain, gzipped);

    let gzipped_bytes = gzipped.serialize().unwrap();
    assert_eq!(plain, Message::deserialize(gzipped_bytes).unwrap());
}