libp2prs_traits/
lib.rs

1// Copyright 2020 Netwarps Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21mod copy;
22
23use std::io;
24use std::io::ErrorKind;
25
26use async_trait::async_trait;
27use futures::io::{ReadHalf, WriteHalf};
28use futures::prelude::*;
29use futures::{AsyncReadExt, AsyncWriteExt};
30
31pub use copy::copy;
32
33/// Read Trait for async/await
34///
35#[async_trait]
36pub trait ReadEx: Send {
37    /// Reads some bytes from the byte stream.
38    ///
39    /// On success, returns the total number of bytes read.
40    ///
41    /// If the return value is `Ok(n)`, then it must be guaranteed that
42    /// `0 <= n <= buf.len()`. A nonzero `n` value indicates that the buffer has been
43    /// filled with `n` bytes of data. If `n` is `0`, then it can indicate one of two
44    /// scenarios:
45    ///
46    /// 1. This reader has reached its "end of file" and will likely no longer be able to
47    ///    produce bytes. Note that this does not mean that the reader will always no
48    ///    longer be able to produce bytes.
49    /// 2. The buffer specified was 0 bytes in length.
50    ///
51    /// Attempt to read bytes from underlying stream object.
52    ///
53    /// On success, returns `Ok(n)`.
54    /// Otherwise, returns `Err(io:Error)`
55    async fn read2(&mut self, buf: &mut [u8]) -> Result<usize, io::Error>;
56
57    /// Reads the exact number of bytes requested.
58    ///
59    /// On success, returns `Ok(())`.
60    /// Otherwise, returns `Err(io:Error)`.
61    async fn read_exact2<'a>(&'a mut self, buf: &'a mut [u8]) -> Result<(), io::Error> {
62        let mut buf_piece = buf;
63        while !buf_piece.is_empty() {
64            let n = self.read2(buf_piece).await?;
65            if n == 0 {
66                return Err(ErrorKind::UnexpectedEof.into());
67            }
68
69            let (_, rest) = buf_piece.split_at_mut(n);
70            buf_piece = rest;
71        }
72        Ok(())
73    }
74
75    /// Reads a fixed-length integer from the underlying IO.
76    ///
77    /// On success, return `Ok(n)`.
78    /// Otherwise, returns `Err(io:Error)`.
79    async fn read_fixed_u32(&mut self) -> Result<usize, io::Error> {
80        let mut len = [0; 4];
81        self.read_exact2(&mut len).await?;
82        let n = u32::from_be_bytes(len) as usize;
83
84        Ok(n)
85    }
86
87    /// Reads a variable-length integer from the underlying IO.
88    ///
89    /// As a special exception, if the `IO` is empty and EOFs right at the beginning, then we
90    /// return `Ok(0)`.
91    ///
92    /// On success, return `Ok(n)`.
93    /// Otherwise, returns `Err(io:Error)`.
94    ///
95    /// > **Note**: This function reads bytes one by one from the underlying IO. It is therefore encouraged
96    /// >           to use some sort of buffering mechanism.
97    async fn read_varint(&mut self) -> Result<usize, io::Error> {
98        let mut buffer = unsigned_varint::encode::usize_buffer();
99        let mut buffer_len = 0;
100
101        loop {
102            match self.read2(&mut buffer[buffer_len..=buffer_len]).await? {
103                0 => {
104                    // Reaching EOF before finishing to read the length is an error, unless the EOF is
105                    // at the very beginning of the substream, in which case we assume that the data is
106                    // empty.
107                    if buffer_len == 0 {
108                        return Ok(0);
109                    } else {
110                        return Err(io::ErrorKind::UnexpectedEof.into());
111                    }
112                }
113                n => debug_assert_eq!(n, 1),
114            }
115
116            buffer_len += 1;
117
118            match unsigned_varint::decode::usize(&buffer[..buffer_len]) {
119                Ok((len, _)) => return Ok(len),
120                Err(unsigned_varint::decode::Error::Overflow) => {
121                    return Err(io::Error::new(io::ErrorKind::InvalidData, "overflow in variable-length integer"));
122                }
123                // TODO: why do we have a `__Nonexhaustive` variant in the error? I don't know how to process it
124                // Err(unsigned_varint::decode::Error::Insufficient) => {}
125                Err(_) => {}
126            }
127        }
128    }
129
130    /// Reads a fixed length-prefixed message from the underlying IO.
131    ///
132    /// The `max_size` parameter is the maximum size in bytes of the message that we accept. This is
133    /// necessary in order to avoid DoS attacks where the remote sends us a message of several
134    /// gigabytes.
135    ///
136    /// > **Note**: Assumes that a fixed-length prefix indicates the length of the message. This is
137    /// >           compatible with what `write_one_fixed` does.
138    async fn read_one_fixed(&mut self, max_size: usize) -> Result<Vec<u8>, io::Error> {
139        let len = self.read_fixed_u32().await?;
140        if len > max_size {
141            return Err(io::Error::new(
142                io::ErrorKind::InvalidData,
143                format!("Received data size over maximum frame length: {}>{}", len, max_size),
144            ));
145        }
146
147        let mut buf = vec![0; len];
148        self.read_exact2(&mut buf).await?;
149        Ok(buf)
150    }
151
152    /// Reads a variable length-prefixed message from the underlying IO.
153    ///
154    /// The `max_size` parameter is the maximum size in bytes of the message that we accept. This is
155    /// necessary in order to avoid DoS attacks where the remote sends us a message of several
156    /// gigabytes.
157    ///
158    /// On success, returns `Ok(Vec<u8>)`.
159    /// Otherwise, returns `Err(io:Error)`.
160    ///
161    /// > **Note**: Assumes that a variable-length prefix indicates the length of the message. This is
162    /// >           compatible with what `write_one` does.
163    async fn read_one(&mut self, max_size: usize) -> Result<Vec<u8>, io::Error> {
164        let len = self.read_varint().await?;
165        if len > max_size {
166            return Err(io::Error::new(
167                io::ErrorKind::InvalidData,
168                format!("Received data size over maximum frame length: {}>{}", len, max_size),
169            ));
170        }
171
172        let mut buf = vec![0; len];
173        self.read_exact2(&mut buf).await?;
174        Ok(buf)
175    }
176}
177
178/// Write Trait for async/await
179///
180#[async_trait]
181pub trait WriteEx: Send {
182    /// Attempt to write bytes from `buf` into the object.
183    ///
184    /// On success, returns `Ok(num_bytes_written)`.
185    /// Otherwise, returns `Err(io:Error)`
186    async fn write2(&mut self, buf: &[u8]) -> Result<usize, io::Error>;
187    /// Attempt to write the entire contents of data into object.
188    ///
189    /// The operation will not complete until all the data has been written.
190    ///
191    /// On success, returns `Ok(())`.
192    /// Otherwise, returns `Err(io:Error)`
193    async fn write_all2(&mut self, buf: &[u8]) -> Result<(), io::Error> {
194        let mut buf_piece = buf;
195        while !buf_piece.is_empty() {
196            let n = self.write2(buf_piece).await?;
197            if n == 0 {
198                return Err(io::ErrorKind::WriteZero.into());
199            }
200
201            let (_, rest) = buf_piece.split_at(n);
202            buf_piece = rest;
203        }
204        Ok(())
205    }
206
207    /// Writes a variable-length integer to the underlying IO.
208    ///
209    /// On success, returns `Ok(())`.
210    /// Otherwise, returns `Err(io:Error)`
211    ///
212    /// > **Note**: Does **NOT** flush the IO.
213    async fn write_varint(&mut self, len: usize) -> Result<(), io::Error> {
214        let mut len_data = unsigned_varint::encode::usize_buffer();
215        let encoded_len = unsigned_varint::encode::usize(len, &mut len_data).len();
216        self.write_all2(&len_data[..encoded_len]).await?;
217        Ok(())
218    }
219
220    /// Writes a fixed-length integer to the underlying IO.
221    ///
222    /// On success, returns `Ok(())`.
223    /// Otherwise, returns `Err(io:Error)`
224    ///
225    /// > **Note**: Does **NOT** flush the IO.
226    async fn write_fixed_u32(&mut self, len: usize) -> Result<(), io::Error> {
227        self.write_all2((len as u32).to_be_bytes().as_ref()).await?;
228        Ok(())
229    }
230
231    /// Send a fixed length message to the underlying IO, then flushes the writing side.
232    ///
233    /// > **Note**: Prepends a fixed-length prefix indicate the length of the message. This is
234    /// >           compatible with what `read_one_fixed` expects.
235    async fn write_one_fixed(&mut self, buf: &[u8]) -> Result<(), io::Error> {
236        self.write_fixed_u32(buf.len()).await?;
237        self.write_all2(buf).await?;
238        self.flush2().await?;
239        Ok(())
240    }
241
242    /// Send a variable length message to the underlying IO, then flushes the writing side.
243    ///
244    /// On success, returns `Ok(())`.
245    /// Otherwise, returns `Err(io:Error)`
246    ///
247    /// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is
248    /// >           compatible with what `read_one` expects.
249    async fn write_one(&mut self, buf: &[u8]) -> Result<(), io::Error> {
250        self.write_varint(buf.len()).await?;
251        self.write_all2(buf).await?;
252        self.flush2().await?;
253        Ok(())
254    }
255
256    /// Attempt to flush the object, ensuring that any buffered data reach
257    /// their destination.
258    ///
259    /// On success, returns `Ok(())`.
260    /// Otherwise, returns `Err(io:Error)`
261    async fn flush2(&mut self) -> Result<(), io::Error>;
262
263    /// Attempt to close the object.
264    ///
265    /// On success, returns `Ok(())`.
266    /// Otherwise, returns `Err(io:Error)`
267    async fn close2(&mut self) -> Result<(), io::Error>;
268}
269
270#[async_trait]
271impl<T: AsyncRead + Unpin + Send> ReadEx for T {
272    async fn read2(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
273        let n = AsyncReadExt::read(self, buf).await?;
274        Ok(n)
275    }
276}
277
278#[async_trait]
279impl<T: AsyncWrite + Unpin + Send> WriteEx for T {
280    async fn write2(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
281        AsyncWriteExt::write(self, buf).await
282    }
283
284    async fn flush2(&mut self) -> Result<(), io::Error> {
285        AsyncWriteExt::flush(self).await
286    }
287
288    async fn close2(&mut self) -> Result<(), io::Error> {
289        AsyncWriteExt::close(self).await
290    }
291}
292
293///
294/// SplitEx Trait for read and write separation
295///
296pub trait SplitEx {
297    /// read half
298    type Reader: ReadEx + Unpin;
299    /// write half
300    type Writer: WriteEx + Unpin;
301
302    /// split Self to independent reader and writer
303    fn split(self) -> (Self::Reader, Self::Writer);
304}
305
306// a common way to support SplitEx for T, requires T: AsyncRead+AsyncWrite
307impl<T: AsyncRead + AsyncWrite + Send + Unpin> SplitEx for T {
308    type Reader = ReadHalf<T>;
309    type Writer = WriteHalf<T>;
310
311    fn split(self) -> (Self::Reader, Self::Writer) {
312        futures::AsyncReadExt::split(self)
313    }
314}
315
316/// SplittableReadWrite trait for simplifying generic type constraints
317pub trait SplittableReadWrite: ReadEx + WriteEx + SplitEx + Unpin + 'static {}
318
319impl<T: ReadEx + WriteEx + SplitEx + Unpin + 'static> SplittableReadWrite for T {}
320
#[cfg(test)]
mod tests {
    use super::*;
    use futures::io::{self, AsyncReadExt, Cursor};
    use libp2prs_runtime::task;

    /// In-memory stream with hand-written `ReadEx` / `WriteEx` impls, used to exercise the
    /// traits' provided methods.
    struct Test(Cursor<Vec<u8>>);

    #[async_trait]
    impl ReadEx for Test {
        async fn read2(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
            self.0.read(buf).await
        }
    }

    #[async_trait]
    impl WriteEx for Test {
        async fn write2(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
            self.0.write(buf).await
        }

        async fn flush2(&mut self) -> Result<(), io::Error> {
            self.0.flush().await
        }

        async fn close2(&mut self) -> Result<(), io::Error> {
            self.0.close().await
        }
    }

    /// Read Vec<u8>
    #[test]
    fn test_read() {
        task::block_on(async {
            let mut reader = Test(Cursor::new(vec![1, 2, 3, 4]));
            let mut output = [0u8; 3];
            let bytes = reader.read2(&mut output[..]).await.unwrap();

            assert_eq!(bytes, 3);
            assert_eq!(output, [1, 2, 3]);
        });
    }

    // Read string
    #[test]
    fn test_read_string() {
        task::block_on(async {
            let mut reader = Test(Cursor::new(b"hello world".to_vec()));
            let mut output = [0u8; 3];
            let bytes = reader.read2(&mut output[..]).await.unwrap();

            assert_eq!(bytes, 3);
            assert_eq!(output, [104, 101, 108]);
        });
    }

    #[test]
    fn test_read_exact() {
        task::block_on(async {
            let mut reader = Test(Cursor::new(vec![1, 2, 3, 4]));
            let mut output = [0u8; 3];
            reader.read_exact2(&mut output[..]).await.unwrap();

            assert_eq!(output, [1, 2, 3]);
        });
    }

    #[test]
    fn test_read_exact_unexpected_eof() {
        task::block_on(async {
            let mut reader = Test(Cursor::new(vec![1, 2]));
            let mut output = [0u8; 3];
            let err = reader.read_exact2(&mut output[..]).await.unwrap_err();

            assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    #[test]
    fn test_read_fixed_u32() {
        task::block_on(async {
            let mut reader = Test(Cursor::new(b"hello world".to_vec()));
            let size = reader.read_fixed_u32().await.unwrap();

            // `hell` interpreted as a big-endian u32.
            assert_eq!(size, 1751477356);
        });
    }

    #[test]
    fn test_read_varint() {
        task::block_on(async {
            let mut reader = Test(Cursor::new(vec![1, 2, 3, 4, 5, 6]));
            let size = reader.read_varint().await.unwrap();

            assert_eq!(size, 1);
        });
    }

    #[test]
    fn test_read_varint_multi_byte() {
        task::block_on(async {
            // 300 encodes as 0xAC 0x02.
            let mut reader = Test(Cursor::new(vec![0xAC, 0x02]));
            let size = reader.read_varint().await.unwrap();

            assert_eq!(size, 300);
        });
    }

    #[test]
    fn test_read_varint_empty_is_zero() {
        task::block_on(async {
            let mut reader = Test(Cursor::new(Vec::new()));
            let size = reader.read_varint().await.unwrap();

            assert_eq!(size, 0);
        });
    }

    #[test]
    fn test_read_varint_truncated() {
        task::block_on(async {
            // Continuation bit set, then EOF.
            let mut reader = Test(Cursor::new(vec![0x80]));
            let err = reader.read_varint().await.unwrap_err();

            assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    #[test]
    fn test_read_one() {
        task::block_on(async {
            let mut reader = Test(Cursor::new(vec![11, 104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]));
            let output = reader.read_one(11).await.unwrap();

            assert_eq!(output, b"hello world");
        });
    }

    #[test]
    fn test_read_one_over_max_size() {
        task::block_on(async {
            let mut reader = Test(Cursor::new(vec![11, 104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]));
            let err = reader.read_one(10).await.unwrap_err();

            assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        });
    }

    #[test]
    fn test_write() {
        task::block_on(async {
            let mut writer = Test(Cursor::new(vec![0u8; 5]));
            let size = writer.write2(&[1, 2, 3, 4]).await.unwrap();

            assert_eq!(size, 4);
            assert_eq!(writer.0.get_ref(), &[1, 2, 3, 4, 0]);
        });
    }

    #[test]
    fn test_write_all2() {
        task::block_on(async {
            let mut writer = Test(Cursor::new(vec![0u8; 4]));
            let output = vec![1, 2, 3, 4, 5];
            writer.write_all2(&output[..]).await.unwrap();

            assert_eq!(writer.0.get_ref(), &[1, 2, 3, 4, 5]);
        });
    }

    #[test]
    fn test_write_fixed_u32() {
        task::block_on(async {
            let mut writer = Test(Cursor::new(b"hello world".to_vec()));
            writer.write_fixed_u32(1751477356).await.unwrap();

            // The big-endian bytes of 1751477356 are `hell`, so the buffer is unchanged and
            // the cursor has advanced past the 4 written bytes.
            assert_eq!(writer.0.position(), 4);
            assert_eq!(writer.0.get_ref(), b"hello world");
        });
    }

    #[test]
    fn test_write_varint() {
        task::block_on(async {
            let mut writer = Test(Cursor::new(vec![2, 2, 3, 4, 5, 6]));
            writer.write_varint(1).await.unwrap();

            assert_eq!(writer.0.position(), 1);
            assert_eq!(writer.0.get_ref(), &[1, 2, 3, 4, 5, 6]);
        });
    }

    #[test]
    fn test_write_one() {
        task::block_on(async {
            let mut writer = Test(Cursor::new(vec![0u8; 0]));
            writer.write_one(b"hello world").await.unwrap();

            assert_eq!(writer.0.get_ref(), &[11, 104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]);
        });
    }

    #[test]
    fn test_write_one_fixed_round_trip() {
        task::block_on(async {
            let mut writer = Test(Cursor::new(Vec::new()));
            writer.write_one_fixed(b"hello").await.unwrap();

            let mut reader = Test(Cursor::new(writer.0.into_inner()));
            let output = reader.read_one_fixed(16).await.unwrap();

            assert_eq!(output, b"hello");
        });
    }
}