beetle_bitswap_next/
protocol.rs1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4
5use asynchronous_codec::{Decoder, Encoder, Framed};
6use bytes::{Bytes, BytesMut};
7use futures::future;
8use futures::io::{AsyncRead, AsyncWrite};
9use libp2p::core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
10use libp2p::StreamProtocol;
11use quick_protobuf::{MessageWrite, Writer};
12use unsigned_varint::codec;
13
14use crate::{handler::BitswapHandlerError, message::BitswapMessage};
15
/// Default upper bound, in bytes, on one length-prefixed Bitswap payload
/// (2 MiB); `ProtocolConfig::default` uses it as `max_transmit_size`.
const MAX_BUF_SIZE: usize = 2 * 1024 * 1024;
17
/// Bitswap protocol versions understood by this module.
///
/// The derived `Ord` follows declaration order (which matches the explicit
/// discriminants), so older versions compare lower:
/// `Legacy < Bitswap100 < Bitswap110 < Bitswap120`.
///
/// `Hash` is derived alongside `Eq` so versions can be used as keys in
/// hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ProtocolId {
    /// `/ipfs/bitswap`
    Legacy = 0,
    /// `/ipfs/bitswap/1.0.0`
    Bitswap100 = 1,
    /// `/ipfs/bitswap/1.1.0`
    Bitswap110 = 2,
    /// `/ipfs/bitswap/1.2.0`
    Bitswap120 = 3,
}
25
26const BITSWAP_LEGACY: StreamProtocol = StreamProtocol::new("/ipfs/bitswap");
27const BITSWAP_100: StreamProtocol = StreamProtocol::new("/ipfs/bitswap/1.0.0");
28const BITSWAP_110: StreamProtocol = StreamProtocol::new("/ipfs/bitswap/1.1.0");
29const BITSWAP_120: StreamProtocol = StreamProtocol::new("/ipfs/bitswap/1.2.0");
30
31impl ProtocolId {
32 pub fn try_from(value: impl AsRef<str>) -> Option<Self> {
33 let value = value.as_ref();
34
35 if value == BITSWAP_LEGACY {
36 Some(ProtocolId::Legacy)
37 } else if value == BITSWAP_100 {
38 Some(ProtocolId::Bitswap100)
39 } else if value == BITSWAP_110 {
40 Some(ProtocolId::Bitswap110)
41 } else if value == BITSWAP_120 {
42 Some(ProtocolId::Bitswap120)
43 } else {
44 None
45 }
46 }
47}
48
49impl AsRef<str> for ProtocolId {
50 fn as_ref(&self) -> &str {
51 match *self {
52 ProtocolId::Legacy => "/ipfs/bitswap",
53 ProtocolId::Bitswap100 => "/ipfs/bitswap/1.0.0",
54 ProtocolId::Bitswap110 => "/ipfs/bitswap/1.1.0",
55 ProtocolId::Bitswap120 => "/ipfs/bitswap/1.2.0",
56 }
57 }
58}
59
60impl From<ProtocolId> for StreamProtocol {
61 fn from(protocol: ProtocolId) -> Self {
62 match protocol {
63 ProtocolId::Legacy => BITSWAP_LEGACY,
64 ProtocolId::Bitswap100 => BITSWAP_100,
65 ProtocolId::Bitswap110 => BITSWAP_110,
66 ProtocolId::Bitswap120 => BITSWAP_120,
67 }
68 }
69}
70
71impl ProtocolId {
72 pub fn supports_have(self) -> bool {
73 matches!(self, ProtocolId::Bitswap120)
74 }
75}
76
77#[derive(Clone, Debug, PartialEq, Eq)]
78pub struct ProtocolConfig {
79 pub protocol_ids: Vec<ProtocolId>,
81 pub max_transmit_size: usize,
83}
84
85impl Default for ProtocolConfig {
86 fn default() -> Self {
87 ProtocolConfig {
88 protocol_ids: vec![
89 ProtocolId::Bitswap120,
90 ProtocolId::Bitswap110,
91 ProtocolId::Bitswap100,
92 ProtocolId::Legacy,
93 ],
94 max_transmit_size: MAX_BUF_SIZE,
95 }
96 }
97}
98
99impl UpgradeInfo for ProtocolConfig {
100 type Info = ProtocolId;
101 type InfoIter = Vec<Self::Info>;
102
103 fn protocol_info(&self) -> Self::InfoIter {
104 self.protocol_ids.clone()
105 }
106}
107
108impl<TSocket> InboundUpgrade<TSocket> for ProtocolConfig
109where
110 TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
111{
112 type Output = Framed<TSocket, BitswapCodec>;
113 type Error = BitswapHandlerError;
114 #[allow(clippy::type_complexity)]
115 type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
116
117 #[inline]
118 fn upgrade_inbound(self, socket: TSocket, protocol_id: Self::Info) -> Self::Future {
119 let mut length_codec = codec::UviBytes::default();
120 length_codec.set_max_len(self.max_transmit_size);
121 Box::pin(future::ok(Framed::new(
122 socket,
123 BitswapCodec::new(length_codec, protocol_id),
124 )))
125 }
126}
127
128impl<TSocket> OutboundUpgrade<TSocket> for ProtocolConfig
129where
130 TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
131{
132 type Output = Framed<TSocket, BitswapCodec>;
133 type Error = BitswapHandlerError;
134 #[allow(clippy::type_complexity)]
135 type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
136
137 #[inline]
138 fn upgrade_outbound(self, socket: TSocket, protocol_id: Self::Info) -> Self::Future {
139 let mut length_codec = codec::UviBytes::default();
140 length_codec.set_max_len(self.max_transmit_size);
141 Box::pin(future::ok(Framed::new(
142 socket,
143 BitswapCodec::new(length_codec, protocol_id),
144 )))
145 }
146}
147
148pub struct BitswapCodec {
150 pub length_codec: codec::UviBytes,
152 pub protocol: ProtocolId,
153}
154
155impl fmt::Debug for BitswapCodec {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 f.debug_struct("BitswapCodec")
158 .field("length_codec", &"unsigned_varint::codec::UviBytes")
159 .field("protocol", &self.protocol)
160 .finish()
161 }
162}
163
164impl BitswapCodec {
165 pub fn new(length_codec: codec::UviBytes, protocol: ProtocolId) -> Self {
166 BitswapCodec {
167 length_codec,
168 protocol,
169 }
170 }
171}
172
173impl Encoder for BitswapCodec {
174 type Item = BitswapMessage;
175 type Error = BitswapHandlerError;
176
177 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
178 tracing::trace!("sending message protocol: {:?}\n{:?}", self.protocol, item);
179
180 let message = match self.protocol {
181 ProtocolId::Legacy | ProtocolId::Bitswap100 => item.encode_as_proto_v0(),
182 ProtocolId::Bitswap110 | ProtocolId::Bitswap120 => item.encode_as_proto_v1(),
183 };
184 let mut buf = Vec::with_capacity(message.get_size());
185 let mut writer = Writer::new(&mut buf);
186
187 message.write_message(&mut writer).expect("fixed target");
188
189 self.length_codec
191 .encode(Bytes::from(buf), dst)
192 .map_err(|_| BitswapHandlerError::MaxTransmissionSize)
193 }
194}
195
196impl Decoder for BitswapCodec {
197 type Item = (BitswapMessage, ProtocolId);
198 type Error = BitswapHandlerError;
199
200 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
201 let packet = match self.length_codec.decode(src).map_err(|e| {
202 if let std::io::ErrorKind::PermissionDenied = e.kind() {
203 BitswapHandlerError::MaxTransmissionSize
204 } else {
205 BitswapHandlerError::Io(e)
206 }
207 })? {
208 Some(p) => p,
209 None => return Ok(None),
210 };
211
212 let message = BitswapMessage::try_from(packet.freeze())?;
213
214 Ok(Some((message, self.protocol)))
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Every protocol version, oldest first.
    const ALL: [ProtocolId; 4] = [
        ProtocolId::Legacy,
        ProtocolId::Bitswap100,
        ProtocolId::Bitswap110,
        ProtocolId::Bitswap120,
    ];

    #[test]
    fn test_ord() {
        let mut protocols = [
            ProtocolId::Bitswap120,
            ProtocolId::Bitswap100,
            ProtocolId::Bitswap110,
            ProtocolId::Legacy,
        ];
        protocols.sort();
        assert_eq!(protocols, ALL);
    }

    /// `try_from`, `AsRef<str>` and the `StreamProtocol` conversion must all
    /// agree on every version's wire name.
    #[test]
    fn test_protocol_name_round_trip() {
        for id in ALL.iter().copied() {
            let name: &str = id.as_ref();
            assert_eq!(ProtocolId::try_from(name), Some(id));

            let stream: StreamProtocol = id.into();
            let stream_name: &str = stream.as_ref();
            assert_eq!(stream_name, name);
        }
    }

    #[test]
    fn test_unknown_protocol_name() {
        assert_eq!(ProtocolId::try_from("/ipfs/bitswap/9.9.9"), None);
        assert_eq!(ProtocolId::try_from("/ipfs/bitswap/"), None);
        assert_eq!(ProtocolId::try_from(""), None);
    }

    #[test]
    fn test_supports_have() {
        let with_have: Vec<ProtocolId> = ALL
            .iter()
            .copied()
            .filter(|id| id.supports_have())
            .collect();
        assert_eq!(with_have, vec![ProtocolId::Bitswap120]);
    }

    #[test]
    fn test_default_config() {
        let config = ProtocolConfig::default();
        let mut newest_first = ALL.to_vec();
        newest_first.reverse();
        assert_eq!(config.protocol_ids, newest_first);
        assert_eq!(config.max_transmit_size, MAX_BUF_SIZE);
    }
}