libp2p_scatter/
protocol.rs

1use std::fmt;
2use std::io::{Error, ErrorKind, Result};
3
4use bytes::Bytes;
5use futures::future::BoxFuture;
6use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
7use libp2p::core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
8use prometheus_client::encoding::{EncodeLabelSet, LabelSetEncoder};
9
10use crate::length_prefixed::{read_length_prefixed, write_length_prefixed};
11
/// Protocol name advertised by both the inbound (`Config`) and outbound
/// (`Message`) upgrades of this module.
const PROTOCOL_INFO: &str = "/ax/broadcast/1.0.0";

/// A topic name of at most [`Topic::MAX_TOPIC_LENGTH`] bytes, stored inline in
/// a fixed-size buffer so the type can be `Copy`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Topic {
    /// Number of meaningful bytes at the start of `bytes`.
    len: u8,
    /// Topic bytes, zero-padded after `len`.
    bytes: [u8; 64],
}

impl Topic {
    /// Maximum topic length in bytes.
    ///
    /// The message header stores the topic length in the upper six bits of a
    /// single byte, so 63 is the longest topic that can be put on the wire. A
    /// 64-byte topic would wrap to an encoded length of 0 and corrupt the
    /// message.
    pub const MAX_TOPIC_LENGTH: usize = 63;

    /// Creates a topic from raw bytes.
    ///
    /// # Panics
    ///
    /// Panics if `topic` is longer than [`Topic::MAX_TOPIC_LENGTH`] bytes.
    pub fn new(topic: &[u8]) -> Self {
        assert!(
            topic.len() <= Self::MAX_TOPIC_LENGTH,
            "topic is {} bytes long, the maximum is {}",
            topic.len(),
            Self::MAX_TOPIC_LENGTH,
        );
        let mut bytes = [0u8; 64];
        bytes[..topic.len()].copy_from_slice(topic);
        Self {
            // Cannot truncate: the length was checked against MAX_TOPIC_LENGTH above.
            len: topic.len() as u8,
            bytes,
        }
    }
}
32
33impl EncodeLabelSet for Topic {
34    fn encode(&self, mut encoder: LabelSetEncoder) -> fmt::Result {
35        use prometheus_client::encoding::{EncodeLabelKey, EncodeLabelValue};
36
37        let mut label_encoder = encoder.encode_label();
38        let mut key_encoder = label_encoder.encode_label_key()?;
39        EncodeLabelKey::encode(&"topic", &mut key_encoder)?;
40        let mut value_encoder = key_encoder.encode_label_value()?;
41        let value = String::from_utf8_lossy(self.as_ref());
42        EncodeLabelValue::encode(&value, &mut value_encoder)?;
43        value_encoder.finish()
44    }
45}
46
47impl std::ops::Deref for Topic {
48    type Target = [u8];
49
50    fn deref(&self) -> &Self::Target {
51        self.as_ref()
52    }
53}
54
55impl AsRef<[u8]> for Topic {
56    fn as_ref(&self) -> &[u8] {
57        &self.bytes[..(self.len as usize)]
58    }
59}
60
61#[derive(Clone, Debug, PartialEq, Eq)]
62pub enum Message {
63    Subscribe(Topic),
64    Broadcast(Topic, Bytes),
65    Unsubscribe(Topic),
66}
67
68impl Message {
69    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
70        if bytes.is_empty() {
71            return Err(Error::new(ErrorKind::InvalidData, "empty message"));
72        }
73        let topic_len = (bytes[0] >> 2) as usize;
74        if bytes.len() < topic_len + 1 {
75            return Err(Error::new(
76                ErrorKind::InvalidData,
77                "topic length out of range",
78            ));
79        }
80        let msg_len = bytes.len() - topic_len - 1;
81        let topic = Topic::new(&bytes[1..topic_len + 1]);
82        Ok(match bytes[0] & 0b11 {
83            0b00 => Message::Subscribe(topic),
84            0b10 => Message::Unsubscribe(topic),
85            0b01 => {
86                let mut msg = Vec::with_capacity(msg_len);
87                msg.extend_from_slice(&bytes[(topic_len + 1)..]);
88                Message::Broadcast(topic, msg.into())
89            }
90            _ => return Err(Error::new(ErrorKind::InvalidData, "invalid header")),
91        })
92    }
93
94    pub fn to_bytes(&self) -> Vec<u8> {
95        match self {
96            Message::Subscribe(topic) => {
97                let mut buf = Vec::with_capacity(topic.len() + 1);
98                buf.push((topic.len() as u8) << 2);
99                buf.extend_from_slice(topic);
100                buf
101            }
102            Message::Unsubscribe(topic) => {
103                let mut buf = Vec::with_capacity(topic.len() + 1);
104                buf.push((topic.len() as u8) << 2 | 0b10);
105                buf.extend_from_slice(topic);
106                buf
107            }
108            Message::Broadcast(topic, msg) => {
109                let mut buf = Vec::with_capacity(topic.len() + msg.len() + 1);
110                buf.push((topic.len() as u8) << 2 | 0b01);
111                buf.extend_from_slice(topic);
112                buf.extend_from_slice(msg);
113                buf
114            }
115        }
116    }
117
118    pub fn len(&self) -> usize {
119        match self {
120            Message::Subscribe(topic) => 1 + topic.len(),
121            Message::Unsubscribe(topic) => 1 + topic.len(),
122            Message::Broadcast(topic, msg) => 1 + topic.len() + msg.len(),
123        }
124    }
125}
126
/// Settings for the inbound side of the protocol.
#[derive(Clone, Debug)]
pub struct Config {
    /// Size limit handed to `read_length_prefixed` when reading an inbound
    /// frame.
    pub max_buf_size: usize,
}

impl Config {
    /// Default frame size limit: 4 MiB.
    const DEFAULT_MAX_BUF_SIZE: usize = 4 << 20;

    /// Builder-style setter: returns this config with `max_buf_size` replaced.
    pub fn max_buf_size(self, max_buf_size: usize) -> Self {
        let mut updated = self;
        updated.max_buf_size = max_buf_size;
        updated
    }
}

impl Default for Config {
    fn default() -> Self {
        Config {
            max_buf_size: Config::DEFAULT_MAX_BUF_SIZE,
        }
    }
}
146
147impl UpgradeInfo for Config {
148    type Info = &'static str;
149    type InfoIter = std::iter::Once<Self::Info>;
150
151    fn protocol_info(&self) -> Self::InfoIter {
152        std::iter::once(PROTOCOL_INFO)
153    }
154}
155
156impl<TSocket> InboundUpgrade<TSocket> for Config
157where
158    TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
159{
160    type Output = Message;
161    type Error = Error;
162    type Future = BoxFuture<'static, Result<Self::Output>>;
163
164    fn upgrade_inbound(self, mut socket: TSocket, _info: Self::Info) -> Self::Future {
165        Box::pin(async move {
166            let packet = read_length_prefixed(&mut socket, self.max_buf_size).await?;
167            socket.close().await?;
168            let request = Message::from_bytes(&packet)?;
169            Ok(request)
170        })
171    }
172}
173
174impl UpgradeInfo for Message {
175    type Info = &'static str;
176    type InfoIter = std::iter::Once<Self::Info>;
177
178    fn protocol_info(&self) -> Self::InfoIter {
179        std::iter::once(PROTOCOL_INFO)
180    }
181}
182
183impl<TSocket> OutboundUpgrade<TSocket> for Message
184where
185    TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
186{
187    type Output = ();
188    type Error = Error;
189    type Future = BoxFuture<'static, Result<Self::Output>>;
190
191    fn upgrade_outbound(self, mut socket: TSocket, _info: Self::Info) -> Self::Future {
192        Box::pin(async move {
193            let bytes = self.to_bytes();
194            write_length_prefixed(&mut socket, bytes).await?;
195            socket.close().await?;
196            Ok(())
197        })
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that decoding `bytes` fails with an `InvalidData` error rather
    /// than succeeding or panicking.
    fn assert_invalid(bytes: &[u8]) {
        let err = Message::from_bytes(bytes).expect_err("decoding should fail");
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn test_roundtrip() {
        let topic = Topic::new(b"topic");
        // 63 bytes is the longest topic the six-bit header field can carry.
        let longest = Topic::new(&[b'x'; 63]);
        let msgs = [
            Message::Broadcast(Topic::new(b""), Bytes::from_static(b"")),
            Message::Subscribe(topic),
            Message::Unsubscribe(topic),
            Message::Broadcast(topic, Bytes::from_static(b"content")),
            Message::Subscribe(longest),
            Message::Broadcast(longest, Bytes::from_static(b"payload")),
        ];
        for msg in &msgs {
            let encoded = msg.to_bytes();
            assert_eq!(encoded.len(), msg.len());
            let decoded = Message::from_bytes(&encoded).unwrap();
            assert_eq!(msg, &decoded);
        }
    }

    #[test]
    fn test_invalid_message() {
        // The header announces a one-byte topic, but no topic bytes follow.
        assert_invalid(&[0b0000_0100]);
    }

    #[test]
    fn test_empty_message() {
        assert_invalid(&[]);
    }

    #[test]
    fn test_unassigned_kind() {
        assert_invalid(&[0b0000_0011]);
    }
}