libp2p_broadcast/
protocol.rs1use futures::future::BoxFuture;
2use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
3use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
4use std::io::{Error, ErrorKind, Result};
5use std::sync::Arc;
6
/// Protocol identifier negotiated for every broadcast substream.
const PROTOCOL_INFO: &[u8] = "/ax/broadcast/1.0.0".as_bytes();
8
/// A broadcast topic: a byte string of at most [`Topic::MAX_TOPIC_LENGTH`] bytes stored
/// inline in a fixed buffer, so the type is `Copy`.
///
/// `Eq`/`Ord`/`Hash` are derived over `len` and the whole buffer. This agrees with comparing
/// the visible bytes because [`Topic::new`] zero-fills the unused tail. Ordering compares the
/// length first.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Topic {
    len: u8,
    bytes: [u8; 64],
}

impl Topic {
    /// Maximum topic length in bytes.
    ///
    /// The frame header stores the topic length in its upper 6 bits (see `Message::to_bytes`
    /// / `Message::from_bytes`), so 63 is the largest length that survives encoding. A
    /// 64-byte topic would be written with length 0 (`64u8 << 2` wraps to 0).
    pub const MAX_TOPIC_LENGTH: usize = 63;

    /// Creates a topic holding a copy of `topic`.
    ///
    /// # Panics
    ///
    /// Panics if `topic` is longer than [`Topic::MAX_TOPIC_LENGTH`] bytes.
    pub fn new(topic: &[u8]) -> Self {
        assert!(
            topic.len() <= Self::MAX_TOPIC_LENGTH,
            "topic length {} exceeds maximum of {} bytes",
            topic.len(),
            Self::MAX_TOPIC_LENGTH
        );
        // Unused tail stays zero so the derived comparisons only depend on the visible bytes.
        let mut bytes = [0u8; 64];
        bytes[..topic.len()].copy_from_slice(topic);
        Self {
            // Cannot truncate: the assert above bounds the length to 63.
            len: topic.len() as u8,
            bytes,
        }
    }
}
27
28impl std::ops::Deref for Topic {
29 type Target = [u8];
30
31 fn deref(&self) -> &Self::Target {
32 self.as_ref()
33 }
34}
35
36impl AsRef<[u8]> for Topic {
37 fn as_ref(&self) -> &[u8] {
38 &self.bytes[..(self.len as usize)]
39 }
40}
41
42#[derive(Clone, Debug, PartialEq, Eq)]
43pub enum Message {
44 Subscribe(Topic),
45 Broadcast(Topic, Arc<[u8]>),
46 Unsubscribe(Topic),
47}
48
49impl Message {
50 fn from_bytes(bytes: &[u8]) -> Result<Self> {
51 if bytes.is_empty() {
52 return Err(Error::new(ErrorKind::InvalidData, "empty message"));
53 }
54 let topic_len = (bytes[0] >> 2) as usize;
55 if bytes.len() < topic_len + 1 {
56 return Err(Error::new(
57 ErrorKind::InvalidData,
58 "topic length out of range",
59 ));
60 }
61 let msg_len = bytes.len() - topic_len - 1;
62 let topic = Topic::new(&bytes[1..topic_len + 1]);
63 Ok(match bytes[0] & 0b11 {
64 0b00 => Message::Subscribe(topic),
65 0b10 => Message::Unsubscribe(topic),
66 0b01 => {
67 let mut msg = Vec::with_capacity(msg_len);
68 msg.extend_from_slice(&bytes[(topic_len + 1)..]);
69 Message::Broadcast(topic, msg.into())
70 }
71 _ => return Err(Error::new(ErrorKind::InvalidData, "invalid header")),
72 })
73 }
74
75 fn to_bytes(&self) -> Vec<u8> {
76 use Message::*;
77 match self {
78 Subscribe(topic) => {
79 let mut buf = Vec::with_capacity(topic.len() + 1);
80 buf.push((topic.len() as u8) << 2);
81 buf.extend_from_slice(topic);
82 buf
83 }
84 Unsubscribe(topic) => {
85 let mut buf = Vec::with_capacity(topic.len() + 1);
86 buf.push((topic.len() as u8) << 2 | 0b10);
87 buf.extend_from_slice(topic);
88 buf
89 }
90 Broadcast(topic, msg) => {
91 let mut buf = Vec::with_capacity(topic.len() + msg.len() + 1);
92 buf.push((topic.len() as u8) << 2 | 0b01);
93 buf.extend_from_slice(topic);
94 buf.extend_from_slice(msg);
95 buf
96 }
97 }
98 }
99}
100
/// Upgrade configuration for the broadcast protocol; it serves as the inbound upgrade.
#[derive(Clone, Debug)]
pub struct BroadcastConfig {
    /// Upper bound in bytes on one inbound frame, passed to `read_length_prefixed`.
    max_buf_size: usize,
}

impl Default for BroadcastConfig {
    /// Accepts inbound frames of up to 4 MiB.
    fn default() -> Self {
        const FOUR_MIB: usize = 4 << 20;
        Self {
            max_buf_size: FOUR_MIB,
        }
    }
}
113
114impl UpgradeInfo for BroadcastConfig {
115 type Info = &'static [u8];
116 type InfoIter = std::iter::Once<Self::Info>;
117
118 fn protocol_info(&self) -> Self::InfoIter {
119 std::iter::once(PROTOCOL_INFO)
120 }
121}
122
123impl<TSocket> InboundUpgrade<TSocket> for BroadcastConfig
124where
125 TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
126{
127 type Output = Message;
128 type Error = Error;
129 type Future = BoxFuture<'static, Result<Self::Output>>;
130
131 fn upgrade_inbound(self, mut socket: TSocket, _info: Self::Info) -> Self::Future {
132 Box::pin(async move {
133 let packet = upgrade::read_length_prefixed(&mut socket, self.max_buf_size).await?;
134 socket.close().await?;
135 let request = Message::from_bytes(&packet)?;
136 Ok(request)
137 })
138 }
139}
140
141impl UpgradeInfo for Message {
142 type Info = &'static [u8];
143 type InfoIter = std::iter::Once<Self::Info>;
144
145 fn protocol_info(&self) -> Self::InfoIter {
146 std::iter::once(PROTOCOL_INFO)
147 }
148}
149
150impl<TSocket> OutboundUpgrade<TSocket> for Message
151where
152 TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
153{
154 type Output = ();
155 type Error = Error;
156 type Future = BoxFuture<'static, Result<Self::Output>>;
157
158 fn upgrade_outbound(self, mut socket: TSocket, _info: Self::Info) -> Self::Future {
159 Box::pin(async move {
160 let bytes = self.to_bytes();
161 upgrade::write_length_prefixed(&mut socket, bytes).await?;
162 socket.close().await?;
163 Ok(())
164 })
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Every message kind must survive `to_bytes` -> `from_bytes` unchanged.
    #[test]
    fn test_roundtrip() {
        let topic = Topic::new(b"topic");
        // 63 bytes is the longest topic the 6-bit header length field can carry.
        let longest = Topic::new(&[b'x'; 63]);
        let msgs = [
            Message::Broadcast(Topic::new(b""), Arc::new(*b"")),
            Message::Subscribe(topic),
            Message::Unsubscribe(topic),
            Message::Broadcast(topic, Arc::new(*b"content")),
            Message::Subscribe(longest),
            Message::Broadcast(longest, Arc::new(*b"payload")),
        ];
        for msg in &msgs {
            let msg2 = Message::from_bytes(&msg.to_bytes()).unwrap();
            assert_eq!(msg, &msg2);
        }
    }

    /// Malformed frames are rejected with `InvalidData` rather than panicking.
    #[test]
    fn test_invalid_message() {
        // Header announces a 1-byte topic but no topic bytes follow.
        let out_of_range = [0b0000_0100];
        let err = Message::from_bytes(&out_of_range).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);

        // Empty frame.
        let err = Message::from_bytes(&[]).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);

        // Kind 0b11 is not assigned to any message.
        let err = Message::from_bytes(&[0b0000_0011]).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }
}