mwc_libp2p_core/upgrade/transfer.rs
1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Contains some helper futures for creating upgrades.
22
23use futures::prelude::*;
24use std::{error, fmt, io};
25
26// TODO: these methods could be on an Ext trait to AsyncWrite
27
28/// Send a message to the given socket, then shuts down the writing side.
29///
30/// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is
31/// > compatible with what `read_one` expects.
32pub async fn write_one(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<[u8]>)
33 -> Result<(), io::Error>
34{
35 write_varint(socket, data.as_ref().len()).await?;
36 socket.write_all(data.as_ref()).await?;
37 socket.close().await?;
38 Ok(())
39}
40
41/// Send a message to the given socket with a length prefix appended to it. Also flushes the socket.
42///
43/// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is
44/// > compatible with what `read_one` expects.
45pub async fn write_with_len_prefix(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<[u8]>)
46 -> Result<(), io::Error>
47{
48 write_varint(socket, data.as_ref().len()).await?;
49 socket.write_all(data.as_ref()).await?;
50 socket.flush().await?;
51 Ok(())
52}
53
54/// Writes a variable-length integer to the `socket`.
55///
56/// > **Note**: Does **NOT** flush the socket.
57pub async fn write_varint(socket: &mut (impl AsyncWrite + Unpin), len: usize)
58 -> Result<(), io::Error>
59{
60 let mut len_data = unsigned_varint::encode::usize_buffer();
61 let encoded_len = unsigned_varint::encode::usize(len, &mut len_data).len();
62 socket.write_all(&len_data[..encoded_len]).await?;
63 Ok(())
64}
65
66/// Reads a variable-length integer from the `socket`.
67///
68/// As a special exception, if the `socket` is empty and EOFs right at the beginning, then we
69/// return `Ok(0)`.
70///
71/// > **Note**: This function reads bytes one by one from the `socket`. It is therefore encouraged
72/// > to use some sort of buffering mechanism.
73pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result<usize, io::Error> {
74 let mut buffer = unsigned_varint::encode::usize_buffer();
75 let mut buffer_len = 0;
76
77 loop {
78 match socket.read(&mut buffer[buffer_len..buffer_len+1]).await? {
79 0 => {
80 // Reaching EOF before finishing to read the length is an error, unless the EOF is
81 // at the very beginning of the substream, in which case we assume that the data is
82 // empty.
83 if buffer_len == 0 {
84 return Ok(0);
85 } else {
86 return Err(io::ErrorKind::UnexpectedEof.into());
87 }
88 }
89 n => debug_assert_eq!(n, 1),
90 }
91
92 buffer_len += 1;
93
94 match unsigned_varint::decode::usize(&buffer[..buffer_len]) {
95 Ok((len, _)) => return Ok(len),
96 Err(unsigned_varint::decode::Error::Overflow) => {
97 return Err(io::Error::new(
98 io::ErrorKind::InvalidData,
99 "overflow in variable-length integer"
100 ));
101 }
102 // TODO: why do we have a `__Nonexhaustive` variant in the error? I don't know how to process it
103 // Err(unsigned_varint::decode::Error::Insufficient) => {}
104 Err(_) => {}
105 }
106 }
107}
108
109/// Reads a length-prefixed message from the given socket.
110///
111/// The `max_size` parameter is the maximum size in bytes of the message that we accept. This is
112/// necessary in order to avoid DoS attacks where the remote sends us a message of several
113/// gigabytes.
114///
115/// > **Note**: Assumes that a variable-length prefix indicates the length of the message. This is
116/// > compatible with what `write_one` does.
117pub async fn read_one(socket: &mut (impl AsyncRead + Unpin), max_size: usize)
118 -> Result<Vec<u8>, ReadOneError>
119{
120 let len = read_varint(socket).await?;
121 if len > max_size {
122 return Err(ReadOneError::TooLarge {
123 requested: len,
124 max: max_size,
125 });
126 }
127
128 let mut buf = vec![0; len];
129 socket.read_exact(&mut buf).await?;
130 Ok(buf)
131}
132
/// Error that can occur while reading a single length-prefixed message with `read_one`.
#[derive(Debug)]
pub enum ReadOneError {
    /// The underlying socket reported an I/O error.
    Io(std::io::Error),
    /// The length prefix announced a message bigger than the caller is willing to accept.
    TooLarge {
        /// Size announced by the remote, in bytes.
        requested: usize,
        /// Largest size the caller accepts, in bytes.
        max: usize,
    },
}

impl From<std::io::Error> for ReadOneError {
    fn from(source: std::io::Error) -> Self {
        Self::Io(source)
    }
}

impl fmt::Display for ReadOneError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(inner) => write!(f, "{}", inner),
            Self::TooLarge { .. } => f.write_str("Received data size over maximum"),
        }
    }
}

impl error::Error for ReadOneError {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        // Only the I/O variant wraps an underlying error.
        if let Self::Io(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `len` as an unsigned varint, i.e. the length prefix `write_varint` produces.
    fn varint_prefix(len: usize) -> Vec<u8> {
        let mut buf = unsigned_varint::encode::usize_buffer();
        unsigned_varint::encode::usize(len, &mut buf).to_vec()
    }

    #[test]
    fn write_one_works() {
        let data = (0..rand::random::<usize>() % 10_000)
            .map(|_| rand::random::<u8>())
            .collect::<Vec<_>>();

        // A growable buffer: a fixed 10_000-byte slice is too small once the varint prefix is
        // added to a ~10_000-byte payload, which made this test fail intermittently.
        let mut socket = futures::io::Cursor::new(Vec::new());
        futures::executor::block_on(write_one(&mut socket, data.clone())).unwrap();
        let out = socket.into_inner();

        let (out_len, out_data) = unsigned_varint::decode::usize(&out).unwrap();
        assert_eq!(out_len, data.len());
        assert_eq!(out_data, &data[..]);
    }

    #[test]
    fn write_with_len_prefix_works() {
        let mut socket = futures::io::Cursor::new(Vec::new());
        futures::executor::block_on(write_with_len_prefix(&mut socket, &b"hello"[..])).unwrap();

        let mut expected = varint_prefix(5);
        expected.extend_from_slice(b"hello");
        assert_eq!(socket.into_inner(), expected);
    }

    #[test]
    fn read_one_works() {
        let original_data = (0..rand::random::<usize>() % 10_000)
            .map(|_| rand::random::<u8>())
            .collect::<Vec<_>>();

        let mut in_buffer = varint_prefix(original_data.len());
        in_buffer.extend_from_slice(&original_data);

        let out = futures::executor::block_on(
            read_one(&mut futures::io::Cursor::new(in_buffer), 10_000)
        ).unwrap();
        assert_eq!(out, original_data);
    }

    #[test]
    fn write_then_read_roundtrip() {
        let mut socket = futures::io::Cursor::new(Vec::new());
        futures::executor::block_on(write_one(&mut socket, &b"some payload"[..])).unwrap();

        let out = futures::executor::block_on(
            read_one(&mut futures::io::Cursor::new(socket.into_inner()), 1024)
        ).unwrap();
        assert_eq!(out, b"some payload".to_vec());
    }

    #[test]
    fn read_one_zero_len() {
        let out = futures::executor::block_on(
            read_one(&mut futures::io::Cursor::new(vec![0u8]), 10_000)
        ).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn read_checks_length() {
        let mut in_buffer = varint_prefix(5_000);
        in_buffer.extend((0..5_000).map(|_| 0u8));

        match futures::executor::block_on(
            read_one(&mut futures::io::Cursor::new(in_buffer), 100)
        ) {
            Err(ReadOneError::TooLarge { requested: 5_000, max: 100 }) => (),
            other => panic!("unexpected result: {:?}", other),
        }
    }

    #[test]
    fn read_one_accepts_empty() {
        let out = futures::executor::block_on(
            read_one(&mut futures::io::Cursor::new(Vec::<u8>::new()), 10_000)
        ).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn read_one_eof_before_len() {
        match futures::executor::block_on(
            read_one(&mut futures::io::Cursor::new(vec![0x80u8]), 10_000)
        ) {
            Err(ReadOneError::Io(ref err)) if err.kind() == io::ErrorKind::UnexpectedEof => (),
            other => panic!("unexpected result: {:?}", other),
        }
    }

    #[test]
    fn read_one_eof_in_body() {
        let mut in_buffer = varint_prefix(10);
        in_buffer.extend_from_slice(&[1, 2, 3]);

        match futures::executor::block_on(
            read_one(&mut futures::io::Cursor::new(in_buffer), 10_000)
        ) {
            Err(ReadOneError::Io(ref err)) if err.kind() == io::ErrorKind::UnexpectedEof => (),
            other => panic!("unexpected result: {:?}", other),
        }
    }

    #[test]
    fn read_varint_decodes_multi_byte() {
        let value = futures::executor::block_on(
            read_varint(&mut futures::io::Cursor::new(vec![0xacu8, 0x02]))
        ).unwrap();
        assert_eq!(value, 300);
    }

    #[test]
    fn read_varint_rejects_overflow() {
        // Every byte has its continuation bit set, so no valid `usize` can be decoded.
        let err = futures::executor::block_on(
            read_varint(&mut futures::io::Cursor::new(vec![0xffu8; 16]))
        ).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}