libp2p_bitswap/
protocol.rs

1use async_trait::async_trait;
2use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3use libipld::cid::Cid;
4use libipld::store::StoreParams;
5use libp2p::request_response::{ProtocolName, RequestResponseCodec};
6use std::convert::TryFrom;
7use std::io::{self, Write};
8use std::marker::PhantomData;
9use thiserror::Error;
10use unsigned_varint::{aio, io::ReadError};
11
// Upper bound on the encoded size of a CID: four unsigned varints (version, codec, hash code and
// digest size), each budgeted at the 10-byte maximum of a u64 varint, plus the digest itself,
// budgeted at 64 bytes.
const MAX_CID_SIZE: usize = {
    const VARINT_FIELDS: usize = 4;
    const MAX_U64_VARINT_LEN: usize = 10;
    const MAX_DIGEST_LEN: usize = 64;
    VARINT_FIELDS * MAX_U64_VARINT_LEN + MAX_DIGEST_LEN
};
14
/// Zero-sized marker type identifying the bitswap protocol.
#[derive(Debug, Clone)]
pub struct BitswapProtocol;
17
18impl ProtocolName for BitswapProtocol {
19    fn protocol_name(&self) -> &[u8] {
20        b"/ipfs-embed/bitswap/1.0.0"
21    }
22}
23
/// Codec that frames bitswap requests and responses on the wire.
///
/// `P` supplies the store parameters; the codec reads `P::MAX_BLOCK_SIZE` to bound message sizes.
/// No value of type `P` is ever stored.
#[derive(Clone)]
pub struct BitswapCodec<P> {
    /// Ties the codec to its store parameters without holding a `P`.
    _marker: PhantomData<P>,
    /// Scratch buffer reused for every message that is encoded or decoded.
    buffer: Vec<u8>,
}
29
30impl<P: StoreParams> Default for BitswapCodec<P> {
31    fn default() -> Self {
32        let capacity = usize::max(P::MAX_BLOCK_SIZE, MAX_CID_SIZE) + 1;
33        debug_assert!(capacity <= u32::MAX as usize);
34        Self {
35            _marker: PhantomData,
36            buffer: Vec::with_capacity(capacity),
37        }
38    }
39}
40
41#[async_trait]
42impl<P: StoreParams> RequestResponseCodec for BitswapCodec<P> {
43    type Protocol = BitswapProtocol;
44    type Request = BitswapRequest;
45    type Response = BitswapResponse;
46
47    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
48    where
49        T: AsyncRead + Send + Unpin,
50    {
51        let msg_len = u32_to_usize(aio::read_u32(&mut *io).await.map_err(|e| match e {
52            ReadError::Io(e) => e,
53            err => other(err),
54        })?);
55        if msg_len > MAX_CID_SIZE + 1 {
56            return Err(invalid_data(MessageTooLarge(msg_len)));
57        }
58        self.buffer.resize(msg_len, 0);
59        io.read_exact(&mut self.buffer).await?;
60        let request = BitswapRequest::from_bytes(&self.buffer).map_err(invalid_data)?;
61        Ok(request)
62    }
63
64    async fn read_response<T>(
65        &mut self,
66        _: &Self::Protocol,
67        io: &mut T,
68    ) -> io::Result<Self::Response>
69    where
70        T: AsyncRead + Send + Unpin,
71    {
72        let msg_len = u32_to_usize(aio::read_u32(&mut *io).await.map_err(|e| match e {
73            ReadError::Io(e) => e,
74            err => other(err),
75        })?);
76        if msg_len > P::MAX_BLOCK_SIZE + 1 {
77            return Err(invalid_data(MessageTooLarge(msg_len)));
78        }
79        self.buffer.resize(msg_len, 0);
80        io.read_exact(&mut self.buffer).await?;
81        let response = BitswapResponse::from_bytes(&self.buffer).map_err(invalid_data)?;
82        Ok(response)
83    }
84
85    async fn write_request<T>(
86        &mut self,
87        _: &Self::Protocol,
88        io: &mut T,
89        req: Self::Request,
90    ) -> io::Result<()>
91    where
92        T: AsyncWrite + Send + Unpin,
93    {
94        self.buffer.clear();
95        req.write_to(&mut self.buffer)?;
96        if self.buffer.len() > MAX_CID_SIZE + 1 {
97            return Err(invalid_data(MessageTooLarge(self.buffer.len())));
98        }
99        let mut buf = unsigned_varint::encode::u32_buffer();
100        let msg_len = unsigned_varint::encode::u32(self.buffer.len() as u32, &mut buf);
101        io.write_all(msg_len).await?;
102        io.write_all(&self.buffer).await?;
103        Ok(())
104    }
105
106    async fn write_response<T>(
107        &mut self,
108        _: &Self::Protocol,
109        io: &mut T,
110        res: Self::Response,
111    ) -> io::Result<()>
112    where
113        T: AsyncWrite + Send + Unpin,
114    {
115        self.buffer.clear();
116        res.write_to(&mut self.buffer)?;
117        if self.buffer.len() > P::MAX_BLOCK_SIZE + 1 {
118            return Err(invalid_data(MessageTooLarge(self.buffer.len())));
119        }
120        let mut buf = unsigned_varint::encode::u32_buffer();
121        let msg_len = unsigned_varint::encode::u32(self.buffer.len() as u32, &mut buf);
122        io.write_all(msg_len).await?;
123        io.write_all(&self.buffer).await?;
124        Ok(())
125    }
126}
127
/// Kind of bitswap request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestType {
    /// Ask whether the peer has the block (encoded as type byte `0`).
    Have,
    /// Ask the peer for the block itself (encoded as type byte `1`).
    Block,
}
133
134/// A request sent to another peer.
135#[derive(Clone, Copy, Debug, Eq, PartialEq)]
136pub struct BitswapRequest {
137    /// type of request: have or block
138    pub ty: RequestType,
139    /// CID the request is for
140    pub cid: Cid,
141}
142
143impl BitswapRequest {
144    /// write binary representation of the request
145    pub fn write_to<W: Write>(&self, w: &mut W) -> io::Result<()> {
146        match self {
147            BitswapRequest {
148                ty: RequestType::Have,
149                cid,
150            } => {
151                w.write_all(&[0])?;
152                cid.write_bytes(&mut *w).map_err(other)?;
153            }
154            BitswapRequest {
155                ty: RequestType::Block,
156                cid,
157            } => {
158                w.write_all(&[1])?;
159                cid.write_bytes(&mut *w).map_err(other)?;
160            }
161        }
162        Ok(())
163    }
164
165    /// read back binary representation of the request
166    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
167        let ty = match bytes[0] {
168            0 => RequestType::Have,
169            1 => RequestType::Block,
170            c => return Err(invalid_data(UnknownMessageType(c))),
171        };
172        let cid = Cid::try_from(&bytes[1..]).map_err(invalid_data)?;
173        Ok(Self { ty, cid })
174    }
175}
176
/// Response to a [BitswapRequest].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BitswapResponse {
    /// Block presence: `true` if the peer has the block, `false` if it does not.
    Have(bool),
    /// The raw bytes of the requested block.
    Block(Vec<u8>),
}
185
186impl BitswapResponse {
187    /// write binary representation of the request
188    pub fn write_to<W: Write>(&self, w: &mut W) -> io::Result<()> {
189        match self {
190            BitswapResponse::Have(have) => {
191                if *have {
192                    w.write_all(&[0])?;
193                } else {
194                    w.write_all(&[2])?;
195                }
196            }
197            BitswapResponse::Block(data) => {
198                w.write_all(&[1])?;
199                w.write_all(data)?;
200            }
201        };
202        Ok(())
203    }
204
205    /// read back binary representation of the request
206    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
207        let res = match bytes[0] {
208            0 | 2 => BitswapResponse::Have(bytes[0] == 0),
209            1 => BitswapResponse::Block(bytes[1..].to_vec()),
210            c => return Err(invalid_data(UnknownMessageType(c))),
211        };
212        Ok(res)
213    }
214}
215
/// Wraps `e` in an [`io::Error`] of kind `InvalidData`.
fn invalid_data<E>(e: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, e)
}
219
/// Wraps `e` in an [`io::Error`] of kind `Other`.
fn other<E>(e: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, e)
}
223
/// Widens a `u32` to `usize`.
///
/// Only compiled for 32-bit and 64-bit targets, where `usize` is at least as wide as `u32`, so
/// the cast cannot lose information.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
fn u32_to_usize(n: u32) -> usize {
    n as usize
}
228
229#[derive(Debug, Error)]
230#[error("unknown message type {0}")]
231pub struct UnknownMessageType(u8);
232
233#[derive(Debug, Error)]
234#[error("message too large {0}")]
235pub struct MessageTooLarge(usize);
236
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use libipld::multihash::Code;
    use multihash::MultihashDigest;

    /// Builds a CIDv1 with codec `0x55` over the BLAKE3-256 digest of `bytes`.
    pub fn create_cid(bytes: &[u8]) -> Cid {
        Cid::new_v1(0x55, Code::Blake3_256.digest(bytes))
    }

    /// Every request type survives an encode/decode round trip unchanged.
    #[test]
    fn test_request_encode_decode() {
        let cases = [
            (RequestType::Have, &b"have_request"[..]),
            (RequestType::Block, &b"block_request"[..]),
        ];
        for (ty, seed) in cases.iter() {
            let request = BitswapRequest {
                ty: *ty,
                cid: create_cid(seed),
            };
            let mut encoded = Vec::with_capacity(MAX_CID_SIZE + 1);
            request.write_to(&mut encoded).unwrap();
            let decoded = BitswapRequest::from_bytes(&encoded).unwrap();
            assert_eq!(&decoded, &request);
        }
    }

    /// Every response variant survives an encode/decode round trip unchanged.
    #[test]
    fn test_response_encode_decode() {
        let cases = [
            BitswapResponse::Have(true),
            BitswapResponse::Have(false),
            BitswapResponse::Block(b"block_response".to_vec()),
        ];
        for response in cases.iter() {
            let mut encoded = Vec::new();
            response.write_to(&mut encoded).unwrap();
            let decoded = BitswapResponse::from_bytes(&encoded).unwrap();
            assert_eq!(&decoded, response);
        }
    }
}