libp2p_broadcast/
protocol.rs

1use futures::future::BoxFuture;
2use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
3use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
4use std::io::{Error, ErrorKind, Result};
5use std::sync::Arc;
6
/// Protocol name advertised and negotiated for every broadcast substream.
const PROTOCOL_INFO: &[u8] = "/ax/broadcast/1.0.0".as_bytes();
8
/// A broadcast topic: a short byte string stored inline so the type stays `Copy`.
///
/// Only the first `len` bytes are meaningful. `Topic::new` zero-fills the rest, so the derived
/// equality, ordering and hash impls only distinguish topics by their content.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Topic {
    /// Number of meaningful bytes at the start of `bytes`.
    len: u8,
    /// Inline storage; bytes past `len` are zero.
    bytes: [u8; 64],
}

impl Topic {
    /// Longest topic, in bytes, that can be put on the wire.
    ///
    /// The message header stores the topic length in six bits, so at most 63 bytes can be
    /// described. (A 64-byte topic would be encoded with a length field of 0 and corrupt the
    /// frame.)
    pub const MAX_TOPIC_LENGTH: usize = 63;

    /// Creates a topic by copying `topic`.
    ///
    /// # Panics
    ///
    /// Panics if `topic` is longer than [`Topic::MAX_TOPIC_LENGTH`] bytes.
    pub fn new(topic: &[u8]) -> Self {
        assert!(
            topic.len() <= Self::MAX_TOPIC_LENGTH,
            "topic is {} bytes long, but at most {} bytes are allowed",
            topic.len(),
            Self::MAX_TOPIC_LENGTH
        );
        let mut bytes = [0u8; 64];
        bytes[..topic.len()].copy_from_slice(topic);
        Self {
            // Cannot truncate: the assertion above bounds the length by 63.
            len: topic.len() as u8,
            bytes,
        }
    }
}

impl std::ops::Deref for Topic {
    type Target = [u8];

    /// Returns the topic bytes without the trailing zero padding.
    fn deref(&self) -> &[u8] {
        &self.bytes[..usize::from(self.len)]
    }
}

impl AsRef<[u8]> for Topic {
    /// Same view as `Deref`: the topic bytes without padding.
    fn as_ref(&self) -> &[u8] {
        self
    }
}
41
42#[derive(Clone, Debug, PartialEq, Eq)]
43pub enum Message {
44    Subscribe(Topic),
45    Broadcast(Topic, Arc<[u8]>),
46    Unsubscribe(Topic),
47}
48
49impl Message {
50    fn from_bytes(bytes: &[u8]) -> Result<Self> {
51        if bytes.is_empty() {
52            return Err(Error::new(ErrorKind::InvalidData, "empty message"));
53        }
54        let topic_len = (bytes[0] >> 2) as usize;
55        if bytes.len() < topic_len + 1 {
56            return Err(Error::new(
57                ErrorKind::InvalidData,
58                "topic length out of range",
59            ));
60        }
61        let msg_len = bytes.len() - topic_len - 1;
62        let topic = Topic::new(&bytes[1..topic_len + 1]);
63        Ok(match bytes[0] & 0b11 {
64            0b00 => Message::Subscribe(topic),
65            0b10 => Message::Unsubscribe(topic),
66            0b01 => {
67                let mut msg = Vec::with_capacity(msg_len);
68                msg.extend_from_slice(&bytes[(topic_len + 1)..]);
69                Message::Broadcast(topic, msg.into())
70            }
71            _ => return Err(Error::new(ErrorKind::InvalidData, "invalid header")),
72        })
73    }
74
75    fn to_bytes(&self) -> Vec<u8> {
76        use Message::*;
77        match self {
78            Subscribe(topic) => {
79                let mut buf = Vec::with_capacity(topic.len() + 1);
80                buf.push((topic.len() as u8) << 2);
81                buf.extend_from_slice(topic);
82                buf
83            }
84            Unsubscribe(topic) => {
85                let mut buf = Vec::with_capacity(topic.len() + 1);
86                buf.push((topic.len() as u8) << 2 | 0b10);
87                buf.extend_from_slice(topic);
88                buf
89            }
90            Broadcast(topic, msg) => {
91                let mut buf = Vec::with_capacity(topic.len() + msg.len() + 1);
92                buf.push((topic.len() as u8) << 2 | 0b01);
93                buf.extend_from_slice(topic);
94                buf.extend_from_slice(msg);
95                buf
96            }
97        }
98    }
99}
100
/// Settings applied when accepting an inbound broadcast substream.
#[derive(Clone, Debug)]
pub struct BroadcastConfig {
    /// Largest length-prefixed frame, in bytes, read from an inbound substream.
    max_buf_size: usize,
}

impl BroadcastConfig {
    /// Default inbound frame limit: 4 MiB.
    const DEFAULT_MAX_BUF_SIZE: usize = 1024 * 1024 * 4;

    /// Returns this configuration with the inbound frame limit set to `max_buf_size` bytes.
    #[must_use]
    pub fn with_max_buf_size(mut self, max_buf_size: usize) -> Self {
        self.max_buf_size = max_buf_size;
        self
    }

    /// Largest inbound frame, in bytes, that this configuration accepts.
    #[must_use]
    pub fn max_buf_size(&self) -> usize {
        self.max_buf_size
    }
}

impl Default for BroadcastConfig {
    /// Configuration with the 4 MiB default frame limit.
    fn default() -> Self {
        Self {
            max_buf_size: Self::DEFAULT_MAX_BUF_SIZE,
        }
    }
}
113
114impl UpgradeInfo for BroadcastConfig {
115    type Info = &'static [u8];
116    type InfoIter = std::iter::Once<Self::Info>;
117
118    fn protocol_info(&self) -> Self::InfoIter {
119        std::iter::once(PROTOCOL_INFO)
120    }
121}
122
123impl<TSocket> InboundUpgrade<TSocket> for BroadcastConfig
124where
125    TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
126{
127    type Output = Message;
128    type Error = Error;
129    type Future = BoxFuture<'static, Result<Self::Output>>;
130
131    fn upgrade_inbound(self, mut socket: TSocket, _info: Self::Info) -> Self::Future {
132        Box::pin(async move {
133            let packet = upgrade::read_length_prefixed(&mut socket, self.max_buf_size).await?;
134            socket.close().await?;
135            let request = Message::from_bytes(&packet)?;
136            Ok(request)
137        })
138    }
139}
140
141impl UpgradeInfo for Message {
142    type Info = &'static [u8];
143    type InfoIter = std::iter::Once<Self::Info>;
144
145    fn protocol_info(&self) -> Self::InfoIter {
146        std::iter::once(PROTOCOL_INFO)
147    }
148}
149
150impl<TSocket> OutboundUpgrade<TSocket> for Message
151where
152    TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
153{
154    type Output = ();
155    type Error = Error;
156    type Future = BoxFuture<'static, Result<Self::Output>>;
157
158    fn upgrade_outbound(self, mut socket: TSocket, _info: Self::Info) -> Self::Future {
159        Box::pin(async move {
160            let bytes = self.to_bytes();
161            upgrade::write_length_prefixed(&mut socket, bytes).await?;
162            socket.close().await?;
163            Ok(())
164        })
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Every message kind survives an encode/decode round trip, including the longest
    /// topic the six-bit length field can carry (63 bytes).
    #[test]
    fn test_roundtrip() {
        let topic = Topic::new(b"topic");
        let longest = Topic::new(&[b'x'; 63]);
        let msgs = [
            Message::Broadcast(Topic::new(b""), Arc::new(*b"")),
            Message::Subscribe(topic),
            Message::Unsubscribe(topic),
            Message::Broadcast(topic, Arc::new(*b"content")),
            Message::Subscribe(longest),
            Message::Broadcast(longest, Arc::new(*b"content")),
        ];
        for msg in &msgs {
            let decoded = Message::from_bytes(&msg.to_bytes()).unwrap();
            assert_eq!(msg, &decoded);
        }
    }

    /// Pins the exact header layout: topic length in the six high bits, kind in the two
    /// low bits.
    #[test]
    fn test_wire_format() {
        let topic = Topic::new(b"ab");
        assert_eq!(Message::Subscribe(topic).to_bytes(), [2 << 2, b'a', b'b']);
        assert_eq!(
            Message::Unsubscribe(topic).to_bytes(),
            [(2 << 2) | 0b10, b'a', b'b']
        );
        assert_eq!(
            Message::Broadcast(topic, Arc::new(*b"c")).to_bytes(),
            [(2 << 2) | 0b01, b'a', b'b', b'c']
        );
    }

    /// Malformed frames are rejected with `InvalidData` rather than by panicking.
    #[test]
    fn test_invalid_message() {
        let invalid: [&[u8]; 3] = [
            // Empty frame.
            &[],
            // Header announces a one-byte topic, but no bytes follow.
            &[0b0000_0100],
            // Reserved message kind 0b11.
            &[0b0000_0011],
        ];
        for bytes in invalid.iter() {
            let err = Message::from_bytes(bytes).unwrap_err();
            assert_eq!(err.kind(), ErrorKind::InvalidData);
        }
    }
}
193}