Skip to main content

avalanche_types/message/
pull_query.rs

1use std::io::{self, Error, ErrorKind};
2
3use crate::{ids, message, proto::pb::p2p};
4use prost::Message as ProstMessage;
5
6#[derive(Debug, PartialEq, Clone)]
7pub struct Message {
8    pub msg: p2p::PullQuery,
9    pub gzip_compress: bool,
10}
11
12impl Default for Message {
13    fn default() -> Self {
14        Message {
15            msg: p2p::PullQuery {
16                chain_id: prost::bytes::Bytes::new(),
17                request_id: 0,
18                deadline: 0,
19                container_id: prost::bytes::Bytes::new(),
20                engine_type: p2p::EngineType::Unspecified.into(),
21                requested_height: 0,
22            },
23            gzip_compress: false,
24        }
25    }
26}
27
28impl Message {
29    #[must_use]
30    pub fn chain_id(mut self, chain_id: ids::Id) -> Self {
31        self.msg.chain_id = prost::bytes::Bytes::from(chain_id.to_vec());
32        self
33    }
34
35    #[must_use]
36    pub fn request_id(mut self, request_id: u32) -> Self {
37        self.msg.request_id = request_id;
38        self
39    }
40
41    #[must_use]
42    pub fn deadline(mut self, deadline: u64) -> Self {
43        self.msg.deadline = deadline;
44        self
45    }
46
47    #[must_use]
48    pub fn container_id(mut self, container_id: ids::Id) -> Self {
49        self.msg.container_id = prost::bytes::Bytes::from(container_id.to_vec());
50        self
51    }
52
53    #[must_use]
54    pub fn gzip_compress(mut self, gzip_compress: bool) -> Self {
55        self.gzip_compress = gzip_compress;
56        self
57    }
58
59    pub fn serialize(&self) -> io::Result<Vec<u8>> {
60        let msg = p2p::Message {
61            message: Some(p2p::message::Message::PullQuery(self.msg.clone())),
62        };
63        let encoded = ProstMessage::encode_to_vec(&msg);
64        if !self.gzip_compress {
65            return Ok(encoded);
66        }
67
68        let uncompressed_len = encoded.len();
69        let compressed = message::compress::pack_gzip(&encoded)?;
70        let msg = p2p::Message {
71            message: Some(p2p::message::Message::CompressedGzip(
72                prost::bytes::Bytes::from(compressed),
73            )),
74        };
75
76        let compressed_len = msg.encoded_len();
77        if uncompressed_len > compressed_len {
78            log::debug!(
79                "pull_query compression saved {} bytes",
80                uncompressed_len - compressed_len
81            );
82        } else {
83            log::debug!(
84                "pull_query compression added {} byte(s)",
85                compressed_len - uncompressed_len
86            );
87        }
88
89        Ok(ProstMessage::encode_to_vec(&msg))
90    }
91
92    pub fn deserialize(d: impl AsRef<[u8]>) -> io::Result<Self> {
93        let buf = bytes::Bytes::from(d.as_ref().to_vec());
94        let p2p_msg: p2p::Message = ProstMessage::decode(buf).map_err(|e| {
95            Error::new(
96                ErrorKind::InvalidData,
97                format!("failed prost::Message::decode '{}'", e),
98            )
99        })?;
100
101        match p2p_msg.message.unwrap() {
102            // was not compressed
103            p2p::message::Message::PullQuery(msg) => Ok(Message {
104                msg,
105                gzip_compress: false,
106            }),
107
108            // was compressed, so need decompress first
109            p2p::message::Message::CompressedGzip(msg) => {
110                let decompressed = message::compress::unpack_gzip(msg.as_ref())?;
111                let decompressed_msg: p2p::Message =
112                    ProstMessage::decode(prost::bytes::Bytes::from(decompressed)).map_err(|e| {
113                        Error::new(
114                            ErrorKind::InvalidData,
115                            format!("failed prost::Message::decode '{}'", e),
116                        )
117                    })?;
118                match decompressed_msg.message.unwrap() {
119                    p2p::message::Message::PullQuery(msg) => Ok(Message {
120                        msg,
121                        gzip_compress: false,
122                    }),
123                    _ => Err(Error::new(
124                        ErrorKind::InvalidInput,
125                        "unknown message type after decompress",
126                    )),
127                }
128            }
129
130            // unknown message enum
131            _ => Err(Error::new(ErrorKind::InvalidInput, "unknown message type")),
132        }
133    }
134}
135
/// Round-trips a randomly populated `pull_query::Message` through
/// `serialize`/`deserialize`, first without and then with gzip compression.
///
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- message::pull_query::test_message --exact --show-output
#[test]
fn test_message() {
    // Ignore the result: another test may already have installed the logger.
    let _ = env_logger::builder()
        .filter_level(log::LevelFilter::Debug)
        .is_test(true)
        .try_init();

    let random_id = || ids::Id::from_slice(&random_manager::secure_bytes(32).unwrap());

    let plain = Message::default()
        .chain_id(random_id())
        .request_id(random_manager::u32())
        .deadline(random_manager::u64())
        .container_id(random_id());

    // Uncompressed round trip reproduces the message exactly.
    let plain_bytes = plain.serialize().unwrap();
    let plain_roundtrip = Message::deserialize(plain_bytes).unwrap();
    assert_eq!(plain, plain_roundtrip);

    // Turning compression on only flips the flag, so the two now differ.
    let compressed = plain.clone().gzip_compress(true);
    assert_ne!(plain, compressed);

    // `deserialize` always yields `gzip_compress == false`, so the decoded value
    // matches the uncompressed original rather than `compressed`.
    let compressed_bytes = compressed.serialize().unwrap();
    let compressed_roundtrip = Message::deserialize(compressed_bytes).unwrap();
    assert_eq!(plain, compressed_roundtrip);
}