1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use crate::Error;
use futures_io::AsyncRead;
use futures_util::AsyncReadExt;
use minicbor::Decode;
use std::io;
/// Reads length-prefixed CBOR values from an underlying byte source `R`.
///
/// On the wire, every item is a 4-byte big-endian length followed by that
/// many bytes of CBOR. The actual reading and decoding is done by
/// [`AsyncReader::read`], available when `R: AsyncRead + Unpin`.
#[derive(Debug)]
pub struct AsyncReader<R> {
    /// The source that bytes are pulled from.
    reader: R,
    /// Holds the payload of the item currently being read; reused between items.
    buffer: Vec<u8>,
    /// Largest item length, in bytes, that will be accepted.
    max_len: usize,
    /// How far the current item has been read. This is kept in the struct
    /// rather than in locals of `read`, so progress survives an early return
    /// from `read`.
    state: State,
}
/// Progress through the frame (length prefix + payload) currently being read.
#[derive(Debug)]
enum State {
    /// Collecting the 4-byte big-endian length prefix. Holds the bytes received
    /// so far and how many of them have been filled in.
    ReadLen([u8; 4], u8),
    /// Collecting the payload. Holds the number of bytes of the reader's
    /// internal buffer that have been filled in.
    ReadVal(usize),
}

impl State {
    /// Starting point of every frame: no length-prefix bytes received yet.
    fn new() -> Self {
        Self::ReadLen([0u8; 4], 0)
    }
}
impl<R> AsyncReader<R> {
pub fn new(reader: R) -> Self {
Self::with_buffer(reader, Vec::new())
}
pub fn with_buffer(reader: R, buffer: Vec<u8>) -> Self {
Self { reader, buffer, max_len: 512 * 1024, state: State::new() }
}
pub fn set_max_len(&mut self, val: u32) {
self.max_len = val as usize
}
pub fn reader(&self) -> &R {
&self.reader
}
pub fn reader_mut(&mut self) -> &mut R {
&mut self.reader
}
pub fn into_parts(self) -> (R, Vec<u8>) {
(self.reader, self.buffer)
}
}
impl<R: AsyncRead + Unpin> AsyncReader<R> {
    /// Read the next length-prefixed CBOR value and decode it.
    ///
    /// Every value must be preceded by a `u32` length in big-endian byte order.
    /// The 4-byte length is read first. Then exactly that many bytes are read
    /// into the internal buffer and decoded with `minicbor::decode`.
    ///
    /// Returns `Ok(None)` when the underlying reader reports end of input
    /// before any byte of a new length prefix has been read.
    ///
    /// `T` may borrow from the internal buffer (`T: Decode<'a>` with
    /// `&'a mut self`), so the decoded value keeps this reader borrowed.
    ///
    /// Progress through the current frame lives in `self.state`, not in local
    /// variables. When a call returns early through `?` on an I/O error, that
    /// progress is kept and the next call continues from the same position.
    /// NOTE(review): this presumably also makes the future safe to drop while
    /// it is suspended in `read(..).await`. Confirm against the `AsyncRead`
    /// implementation in use.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidLen` if the announced length exceeds the configured
    ///   maximum. The state is not reset in this case, so a later call
    ///   re-checks the same prefix (for example after `set_max_len`).
    /// - `Error::Io` with `UnexpectedEof` if input ends in the middle of a
    ///   length prefix or a payload.
    /// - Any I/O error returned by the underlying reader.
    /// - `Error::Decode` if the payload does not decode as `T`. The state has
    ///   already been reset at that point, so the next call starts a new frame.
    ///   NOTE(review): whether trailing bytes after the value inside a frame are
    ///   rejected depends on `minicbor::decode`. Confirm.
    pub async fn read<'a, T: Decode<'a>>(&'a mut self) -> Result<Option<T>, Error> {
        // Drive the framing state machine until a value is decoded, the input
        // ends cleanly, or an error occurs.
        loop {
            match self.state {
                // All 4 length bytes are in. Validate the length and size the
                // payload buffer to exactly that many bytes.
                State::ReadLen(buf, 4) => {
                    let len = u32::from_be_bytes(buf) as usize;
                    if len > self.max_len {
                        // The state is left as-is, so the next call hits this
                        // same check again.
                        return Err(Error::InvalidLen)
                    }
                    self.buffer.clear();
                    self.buffer.resize(len, 0u8);
                    self.state = State::ReadVal(0)
                }
                // Still collecting the length prefix. `o` counts the bytes
                // already received; read into the remaining ones.
                State::ReadLen(ref mut buf, ref mut o) => {
                    let n = self.reader.read(&mut buf[usize::from(*o) ..]).await?;
                    if n == 0 {
                        // End of input on a frame boundary is a clean end of
                        // stream. Inside a prefix, it means the data was truncated.
                        return if *o == 0 {
                            Ok(None)
                        } else {
                            Err(Error::Io(io::ErrorKind::UnexpectedEof.into()))
                        }
                    }
                    // `n` is at most the slice length `4 - o`, so it fits in a
                    // `u8`. This assumes the reader honours that contract.
                    *o += n as u8
                }
                // Payload complete. A zero length lands here immediately.
                // Reset for the next frame first, then decode the buffered bytes.
                State::ReadVal(o) if o >= self.buffer.len() => {
                    self.state = State::new();
                    return minicbor::decode(&self.buffer).map_err(Error::Decode).map(Some)
                }
                // Still collecting the payload: fill the rest of the buffer.
                State::ReadVal(ref mut o) => {
                    let n = self.reader.read(&mut self.buffer[*o ..]).await?;
                    if n == 0 {
                        // Input ended before the announced length was read.
                        return Err(Error::Io(io::ErrorKind::UnexpectedEof.into()))
                    }
                    *o += n
                }
            }
        }
    }
}