libp2p_bitswap/
protocol.rs1use async_trait::async_trait;
2use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3use libipld::cid::Cid;
4use libipld::store::StoreParams;
5use libp2p::request_response::{ProtocolName, RequestResponseCodec};
6use std::convert::TryFrom;
7use std::io::{self, Write};
8use std::marker::PhantomData;
9use thiserror::Error;
10use unsigned_varint::{aio, io::ReadError};
11
/// Upper bound, in bytes, on an encoded CID accepted inside a request frame.
///
/// NOTE(review): presumably budgets four varint fields of at most 10 bytes each
/// (version, codec, hash code, digest length) plus a digest of at most 64 bytes — confirm.
const MAX_CID_SIZE: usize = 64 + 4 * 10;
14
/// Marker type naming the bitswap protocol (`/ipfs-embed/bitswap/1.0.0`).
#[derive(Debug, Clone)]
pub struct BitswapProtocol;
17
18impl ProtocolName for BitswapProtocol {
19 fn protocol_name(&self) -> &[u8] {
20 b"/ipfs-embed/bitswap/1.0.0"
21 }
22}
23
/// Request/response codec for the bitswap protocol.
///
/// A single scratch buffer is reused for every frame read or written, so encoding and
/// decoding do not allocate once the buffer has grown to its working size.
#[derive(Clone)]
pub struct BitswapCodec<P> {
    /// Scratch space holding the payload of the frame currently being read or written.
    buffer: Vec<u8>,
    /// Ties the codec to a store-parameter type without storing one.
    _marker: PhantomData<P>,
}
29
30impl<P: StoreParams> Default for BitswapCodec<P> {
31 fn default() -> Self {
32 let capacity = usize::max(P::MAX_BLOCK_SIZE, MAX_CID_SIZE) + 1;
33 debug_assert!(capacity <= u32::MAX as usize);
34 Self {
35 _marker: PhantomData,
36 buffer: Vec::with_capacity(capacity),
37 }
38 }
39}
40
41#[async_trait]
42impl<P: StoreParams> RequestResponseCodec for BitswapCodec<P> {
43 type Protocol = BitswapProtocol;
44 type Request = BitswapRequest;
45 type Response = BitswapResponse;
46
47 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
48 where
49 T: AsyncRead + Send + Unpin,
50 {
51 let msg_len = u32_to_usize(aio::read_u32(&mut *io).await.map_err(|e| match e {
52 ReadError::Io(e) => e,
53 err => other(err),
54 })?);
55 if msg_len > MAX_CID_SIZE + 1 {
56 return Err(invalid_data(MessageTooLarge(msg_len)));
57 }
58 self.buffer.resize(msg_len, 0);
59 io.read_exact(&mut self.buffer).await?;
60 let request = BitswapRequest::from_bytes(&self.buffer).map_err(invalid_data)?;
61 Ok(request)
62 }
63
64 async fn read_response<T>(
65 &mut self,
66 _: &Self::Protocol,
67 io: &mut T,
68 ) -> io::Result<Self::Response>
69 where
70 T: AsyncRead + Send + Unpin,
71 {
72 let msg_len = u32_to_usize(aio::read_u32(&mut *io).await.map_err(|e| match e {
73 ReadError::Io(e) => e,
74 err => other(err),
75 })?);
76 if msg_len > P::MAX_BLOCK_SIZE + 1 {
77 return Err(invalid_data(MessageTooLarge(msg_len)));
78 }
79 self.buffer.resize(msg_len, 0);
80 io.read_exact(&mut self.buffer).await?;
81 let response = BitswapResponse::from_bytes(&self.buffer).map_err(invalid_data)?;
82 Ok(response)
83 }
84
85 async fn write_request<T>(
86 &mut self,
87 _: &Self::Protocol,
88 io: &mut T,
89 req: Self::Request,
90 ) -> io::Result<()>
91 where
92 T: AsyncWrite + Send + Unpin,
93 {
94 self.buffer.clear();
95 req.write_to(&mut self.buffer)?;
96 if self.buffer.len() > MAX_CID_SIZE + 1 {
97 return Err(invalid_data(MessageTooLarge(self.buffer.len())));
98 }
99 let mut buf = unsigned_varint::encode::u32_buffer();
100 let msg_len = unsigned_varint::encode::u32(self.buffer.len() as u32, &mut buf);
101 io.write_all(msg_len).await?;
102 io.write_all(&self.buffer).await?;
103 Ok(())
104 }
105
106 async fn write_response<T>(
107 &mut self,
108 _: &Self::Protocol,
109 io: &mut T,
110 res: Self::Response,
111 ) -> io::Result<()>
112 where
113 T: AsyncWrite + Send + Unpin,
114 {
115 self.buffer.clear();
116 res.write_to(&mut self.buffer)?;
117 if self.buffer.len() > P::MAX_BLOCK_SIZE + 1 {
118 return Err(invalid_data(MessageTooLarge(self.buffer.len())));
119 }
120 let mut buf = unsigned_varint::encode::u32_buffer();
121 let msg_len = unsigned_varint::encode::u32(self.buffer.len() as u32, &mut buf);
122 io.write_all(msg_len).await?;
123 io.write_all(&self.buffer).await?;
124 Ok(())
125 }
126}
127
/// Kind of answer a [`BitswapRequest`] asks for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestType {
    /// Presence query; encoded on the wire with type tag `0`.
    Have,
    /// Block fetch; encoded on the wire with type tag `1`.
    Block,
}
133
134#[derive(Clone, Copy, Debug, Eq, PartialEq)]
136pub struct BitswapRequest {
137 pub ty: RequestType,
139 pub cid: Cid,
141}
142
143impl BitswapRequest {
144 pub fn write_to<W: Write>(&self, w: &mut W) -> io::Result<()> {
146 match self {
147 BitswapRequest {
148 ty: RequestType::Have,
149 cid,
150 } => {
151 w.write_all(&[0])?;
152 cid.write_bytes(&mut *w).map_err(other)?;
153 }
154 BitswapRequest {
155 ty: RequestType::Block,
156 cid,
157 } => {
158 w.write_all(&[1])?;
159 cid.write_bytes(&mut *w).map_err(other)?;
160 }
161 }
162 Ok(())
163 }
164
165 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
167 let ty = match bytes[0] {
168 0 => RequestType::Have,
169 1 => RequestType::Block,
170 c => return Err(invalid_data(UnknownMessageType(c))),
171 };
172 let cid = Cid::try_from(&bytes[1..]).map_err(invalid_data)?;
173 Ok(Self { ty, cid })
174 }
175}
176
/// Reply to a [`BitswapRequest`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BitswapResponse {
    /// Presence answer; encoded with type tag `0` when `true` and `2` when `false`.
    Have(bool),
    /// Raw block bytes; encoded with type tag `1` followed by the data.
    Block(Vec<u8>),
}
185
186impl BitswapResponse {
187 pub fn write_to<W: Write>(&self, w: &mut W) -> io::Result<()> {
189 match self {
190 BitswapResponse::Have(have) => {
191 if *have {
192 w.write_all(&[0])?;
193 } else {
194 w.write_all(&[2])?;
195 }
196 }
197 BitswapResponse::Block(data) => {
198 w.write_all(&[1])?;
199 w.write_all(data)?;
200 }
201 };
202 Ok(())
203 }
204
205 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
207 let res = match bytes[0] {
208 0 | 2 => BitswapResponse::Have(bytes[0] == 0),
209 1 => BitswapResponse::Block(bytes[1..].to_vec()),
210 c => return Err(invalid_data(UnknownMessageType(c))),
211 };
212 Ok(res)
213 }
214}
215
/// Wraps `e` in an [`io::Error`] of kind [`io::ErrorKind::InvalidData`].
fn invalid_data<E>(e: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, e)
}
219
/// Wraps `e` in an [`io::Error`] of kind [`io::ErrorKind::Other`].
fn other<E>(e: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, e)
}
223
/// Widens a `u32` frame length to `usize`.
///
/// Only compiled for 32- and 64-bit targets, where every `u32` fits in a `usize`,
/// so the cast is lossless.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
fn u32_to_usize(value: u32) -> usize {
    value as usize
}
228
229#[derive(Debug, Error)]
230#[error("unknown message type {0}")]
231pub struct UnknownMessageType(u8);
232
233#[derive(Debug, Error)]
234#[error("message too large {0}")]
235pub struct MessageTooLarge(usize);
236
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use libipld::multihash::Code;
    use multihash::MultihashDigest;

    /// Builds a CIDv1 with codec `0x55` over the Blake3-256 digest of `bytes`.
    pub fn create_cid(bytes: &[u8]) -> Cid {
        Cid::new_v1(0x55, Code::Blake3_256.digest(bytes))
    }

    #[test]
    fn test_request_encode_decode() {
        let cases = [
            (RequestType::Have, &b"have_request"[..]),
            (RequestType::Block, &b"block_request"[..]),
        ];
        let mut encoded = Vec::with_capacity(MAX_CID_SIZE + 1);
        for &(ty, seed) in cases.iter() {
            let request = BitswapRequest {
                ty,
                cid: create_cid(seed),
            };
            encoded.clear();
            request.write_to(&mut encoded).unwrap();
            let decoded = BitswapRequest::from_bytes(&encoded).unwrap();
            assert_eq!(decoded, request);
        }
    }

    #[test]
    fn test_response_encode_decode() {
        let cases = vec![
            BitswapResponse::Have(true),
            BitswapResponse::Have(false),
            BitswapResponse::Block(b"block_response".to_vec()),
        ];
        let mut encoded = Vec::new();
        for response in cases {
            encoded.clear();
            response.write_to(&mut encoded).unwrap();
            let decoded = BitswapResponse::from_bytes(&encoded).unwrap();
            assert_eq!(decoded, response);
        }
    }
}