libp2p_scatter/
protocol.rs1use std::fmt;
2use std::io::{Error, ErrorKind, Result};
3
4use bytes::Bytes;
5use futures::future::BoxFuture;
6use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
7use libp2p::core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
8use prometheus_client::encoding::{EncodeLabelSet, LabelSetEncoder};
9
10use crate::length_prefixed::{read_length_prefixed, write_length_prefixed};
11
/// Protocol name advertised by both the inbound and the outbound upgrade of this module.
const PROTOCOL_INFO: &str = "/ax/broadcast/1.0.0";
13
/// A topic name stored inline in a fixed-size buffer, so it is `Copy` and needs no heap.
///
/// The derived comparisons, ordering and hashing look at `len` first and then at the whole
/// buffer; [`Topic::new`] zero-fills the unused tail so equal names compare equal.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Topic {
    /// Number of meaningful bytes at the start of `bytes`.
    len: u8,
    /// The topic name, zero-padded to the fixed buffer size.
    bytes: [u8; 64],
}

impl Topic {
    /// Longest supported topic, in bytes.
    ///
    /// The message header keeps the topic length in six bits, so 63 is the largest length
    /// that can be put on the wire. A 64-byte topic would be encoded with length 0 and
    /// corrupt the message.
    pub const MAX_TOPIC_LENGTH: usize = 63;

    /// Creates a topic from its raw name.
    ///
    /// # Panics
    ///
    /// Panics if `topic` is longer than [`Topic::MAX_TOPIC_LENGTH`] bytes.
    pub fn new(topic: &[u8]) -> Self {
        assert!(
            topic.len() <= Self::MAX_TOPIC_LENGTH,
            "topic is {} bytes long, but at most {} bytes are supported",
            topic.len(),
            Self::MAX_TOPIC_LENGTH,
        );
        let mut bytes = [0u8; 64];
        bytes[..topic.len()].copy_from_slice(topic);
        Self {
            // Lossless: the assert above bounds the length by 63.
            len: topic.len() as u8,
            bytes,
        }
    }
}
32
33impl EncodeLabelSet for Topic {
34 fn encode(&self, mut encoder: LabelSetEncoder) -> fmt::Result {
35 use prometheus_client::encoding::{EncodeLabelKey, EncodeLabelValue};
36
37 let mut label_encoder = encoder.encode_label();
38 let mut key_encoder = label_encoder.encode_label_key()?;
39 EncodeLabelKey::encode(&"topic", &mut key_encoder)?;
40 let mut value_encoder = key_encoder.encode_label_value()?;
41 let value = String::from_utf8_lossy(self.as_ref());
42 EncodeLabelValue::encode(&value, &mut value_encoder)?;
43 value_encoder.finish()
44 }
45}
46
47impl std::ops::Deref for Topic {
48 type Target = [u8];
49
50 fn deref(&self) -> &Self::Target {
51 self.as_ref()
52 }
53}
54
55impl AsRef<[u8]> for Topic {
56 fn as_ref(&self) -> &[u8] {
57 &self.bytes[..(self.len as usize)]
58 }
59}
60
61#[derive(Clone, Debug, PartialEq, Eq)]
62pub enum Message {
63 Subscribe(Topic),
64 Broadcast(Topic, Bytes),
65 Unsubscribe(Topic),
66}
67
68impl Message {
69 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
70 if bytes.is_empty() {
71 return Err(Error::new(ErrorKind::InvalidData, "empty message"));
72 }
73 let topic_len = (bytes[0] >> 2) as usize;
74 if bytes.len() < topic_len + 1 {
75 return Err(Error::new(
76 ErrorKind::InvalidData,
77 "topic length out of range",
78 ));
79 }
80 let msg_len = bytes.len() - topic_len - 1;
81 let topic = Topic::new(&bytes[1..topic_len + 1]);
82 Ok(match bytes[0] & 0b11 {
83 0b00 => Message::Subscribe(topic),
84 0b10 => Message::Unsubscribe(topic),
85 0b01 => {
86 let mut msg = Vec::with_capacity(msg_len);
87 msg.extend_from_slice(&bytes[(topic_len + 1)..]);
88 Message::Broadcast(topic, msg.into())
89 }
90 _ => return Err(Error::new(ErrorKind::InvalidData, "invalid header")),
91 })
92 }
93
94 pub fn to_bytes(&self) -> Vec<u8> {
95 match self {
96 Message::Subscribe(topic) => {
97 let mut buf = Vec::with_capacity(topic.len() + 1);
98 buf.push((topic.len() as u8) << 2);
99 buf.extend_from_slice(topic);
100 buf
101 }
102 Message::Unsubscribe(topic) => {
103 let mut buf = Vec::with_capacity(topic.len() + 1);
104 buf.push((topic.len() as u8) << 2 | 0b10);
105 buf.extend_from_slice(topic);
106 buf
107 }
108 Message::Broadcast(topic, msg) => {
109 let mut buf = Vec::with_capacity(topic.len() + msg.len() + 1);
110 buf.push((topic.len() as u8) << 2 | 0b01);
111 buf.extend_from_slice(topic);
112 buf.extend_from_slice(msg);
113 buf
114 }
115 }
116 }
117
118 pub fn len(&self) -> usize {
119 match self {
120 Message::Subscribe(topic) => 1 + topic.len(),
121 Message::Unsubscribe(topic) => 1 + topic.len(),
122 Message::Broadcast(topic, msg) => 1 + topic.len() + msg.len(),
123 }
124 }
125}
126
/// Settings of the broadcast protocol; in this module it is also the inbound upgrade.
#[derive(Clone, Debug)]
pub struct Config {
    /// Size limit passed to `read_length_prefixed` when receiving a frame.
    pub max_buf_size: usize,
}

impl Config {
    /// Returns this configuration with `max_buf_size` replaced.
    ///
    /// Consumes `self`, so discarding the result would make the call a no-op.
    #[must_use]
    pub fn max_buf_size(mut self, max_buf_size: usize) -> Self {
        self.max_buf_size = max_buf_size;
        self
    }
}

impl Default for Config {
    /// Uses a 4 MiB receive limit.
    fn default() -> Self {
        Self {
            max_buf_size: 4 * 1024 * 1024,
        }
    }
}
146
147impl UpgradeInfo for Config {
148 type Info = &'static str;
149 type InfoIter = std::iter::Once<Self::Info>;
150
151 fn protocol_info(&self) -> Self::InfoIter {
152 std::iter::once(PROTOCOL_INFO)
153 }
154}
155
156impl<TSocket> InboundUpgrade<TSocket> for Config
157where
158 TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
159{
160 type Output = Message;
161 type Error = Error;
162 type Future = BoxFuture<'static, Result<Self::Output>>;
163
164 fn upgrade_inbound(self, mut socket: TSocket, _info: Self::Info) -> Self::Future {
165 Box::pin(async move {
166 let packet = read_length_prefixed(&mut socket, self.max_buf_size).await?;
167 socket.close().await?;
168 let request = Message::from_bytes(&packet)?;
169 Ok(request)
170 })
171 }
172}
173
174impl UpgradeInfo for Message {
175 type Info = &'static str;
176 type InfoIter = std::iter::Once<Self::Info>;
177
178 fn protocol_info(&self) -> Self::InfoIter {
179 std::iter::once(PROTOCOL_INFO)
180 }
181}
182
183impl<TSocket> OutboundUpgrade<TSocket> for Message
184where
185 TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
186{
187 type Output = ();
188 type Error = Error;
189 type Future = BoxFuture<'static, Result<Self::Output>>;
190
191 fn upgrade_outbound(self, mut socket: TSocket, _info: Self::Info) -> Self::Future {
192 Box::pin(async move {
193 let bytes = self.to_bytes();
194 write_length_prefixed(&mut socket, bytes).await?;
195 socket.close().await?;
196 Ok(())
197 })
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that decoding `bytes` fails with `ErrorKind::InvalidData`, rather than
    /// panicking or failing for an unrelated reason.
    fn assert_invalid(bytes: &[u8]) {
        let err = Message::from_bytes(bytes).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData, "input: {:?}", bytes);
    }

    #[test]
    fn test_roundtrip() {
        let topic = Topic::new(b"topic");
        // 63 bytes is the longest topic the six-bit header length field can describe.
        let longest = Topic::new(&[b'x'; 63]);
        let msgs = [
            Message::Broadcast(Topic::new(b""), Bytes::from_static(b"")),
            Message::Subscribe(topic),
            Message::Unsubscribe(topic),
            Message::Broadcast(topic, Bytes::from_static(b"content")),
            Message::Subscribe(longest),
            Message::Unsubscribe(longest),
            Message::Broadcast(longest, Bytes::from_static(b"content")),
        ];
        for msg in &msgs {
            let bytes = msg.to_bytes();
            // `len` must predict the encoded size exactly.
            assert_eq!(bytes.len(), msg.len());
            let decoded = Message::from_bytes(&bytes).unwrap();
            assert_eq!(msg, &decoded);
        }
    }

    #[test]
    fn test_invalid_message() {
        // No header byte at all.
        assert_invalid(&[]);
        // Header announces a one-byte topic, but nothing follows it.
        assert_invalid(&[0b0000_0100]);
        // Kind `0b11` is not assigned to any message.
        assert_invalid(&[0b0000_0011]);
    }
}