1use async_std::prelude::*;
2use async_std::io::{Read, Write};
3use crate::{Error};
4
/// Returns `true` if `needle` occurs as a contiguous run anywhere in `bytes`.
///
/// An empty `needle` is trivially contained in any slice (the same convention
/// as `str::contains("")`). Runs in O(bytes.len() * needle.len()) time, which
/// is fine for the short delimiter sequences this module searches for.
pub fn vec_has_sequence(bytes: &[u8], needle: &[u8]) -> bool {
    // `windows(0)` would panic, so the empty needle is handled up front.
    needle.is_empty() || bytes.windows(needle.len()).any(|window| window == needle)
}
21
22pub async fn read_protocol_lines<I>(input: &mut I, lines: &mut Vec<String>, limit: Option<usize>) -> Result<usize, Error>
25 where
26 I: Read + Unpin,
27{
28 let mut buffer: Vec<u8> = Vec::new();
29 let mut stage = 0; let mut count = 0; loop {
33 let mut byte = [0u8];
34 let size = match input.read(&mut byte).await {
35 Ok(size) => size,
36 Err(_) => return Err(Error::StreamNotReadable),
37 };
38 let byte = byte[0];
39 count += 1;
40
41 if size == 0 { break;
43 } else if limit.is_some() && Some(count) >= limit {
44 return Err(Error::SizeLimitExceeded(limit.unwrap()));
45 } else if byte == 0x0D { if stage == 0 || stage == 2 {
47 stage += 1;
48 } else {
49 return Err(Error::InvalidData);
50 }
51 } else if byte == 0x0A { if stage == 1 || stage == 3 {
53 let line = match String::from_utf8(buffer.to_vec()) {
54 Ok(line) => line,
55 Err(_) => return Err(Error::InvalidData),
56 };
57 if stage == 3 {
58 break; } else {
60 lines.push(line);
61 buffer.clear();
62 stage += 1;
63 }
64 } else {
65 return Err(Error::InvalidData);
66 }
67 } else { buffer.push(byte);
69 stage = 0;
70 }
71 }
72
73 Ok(count)
74}
75
76pub async fn read_chunked_stream<I>(stream: &mut I, source: &mut Vec<u8>, limit: Option<usize>) -> Result<usize, Error>
83 where
84 I: Read + Unpin,
85{
86 let mut buffer: Vec<u8> = Vec::new();
87 let mut stage = 0; let mut count = 0; loop {
91 let mut byte = [0u8];
92 let size = match stream.read(&mut byte).await {
93 Ok(size) => size,
94 Err(_) => return Err(Error::StreamNotReadable),
95 };
96 let byte = byte[0];
97
98 if size == 0 { break;
100 } else if byte == 0x0D { if stage == 0 || stage == 2 {
102 stage += 1;
103 } else {
104 return Err(Error::InvalidData);
105 }
106 } else if byte == 0x0A { if stage == 1 || stage == 3 {
108 if stage == 3 {
109 break; } else {
111 let length = match String::from_utf8(buffer.to_vec()) {
112 Ok(length) => match i64::from_str_radix(&length, 16) {
113 Ok(length) => length as usize,
114 Err(_) => return Err(Error::InvalidData),
115 },
116 Err(_) => return Err(Error::InvalidData),
117 };
118 if length == 0 {
119 break;
120 } else if limit.is_some() && count + length > limit.unwrap() {
121 return Err(Error::SizeLimitExceeded(limit.unwrap()));
122 } else {
123 read_sized_stream(stream, source, length).await?;
124 read_sized_stream(stream, &mut Vec::new(), 2).await?;
125 count += length;
126 }
127 buffer.clear();
128 stage = 0;
129 }
130 } else {
131 return Err(Error::InvalidData);
132 }
133 } else { buffer.push(byte);
135 }
136 }
137
138 Ok(count)
139}
140
141pub async fn read_sized_stream<I>(stream: &mut I, source: &mut Vec<u8>, length: usize) -> Result<usize, Error>
142 where
143 I: Read + Unpin,
144{
145 let mut bytes = vec![0u8; length];
146 match stream.read_exact(&mut bytes).await {
147 Ok(size) => size,
148 Err(_) => return Err(Error::StreamNotReadable),
149 };
150
151 source.append(&mut bytes);
152
153 Ok(length)
154}
155
156pub async fn relay_chunked_stream<I, O>(input: &mut I, output: &mut O, limit: Option<usize>) -> Result<usize, Error>
163 where
164 I: Write + Read + Unpin,
165 O: Write + Read + Unpin,
166{
167 let mut buffer: Vec<u8> = Vec::new();
168 let mut count = 0;
169 loop {
170 if limit.is_some() && count >= limit.unwrap() {
171 return Err(Error::SizeLimitExceeded(limit.unwrap()));
172 }
173
174 let mut bytes = [0u8; 1024];
175 let size = match input.read(&mut bytes).await {
176 Ok(size) => size,
177 Err(_) => return Err(Error::StreamNotReadable),
178 };
179 let mut bytes = &mut bytes[0..size].to_vec();
180 count += size;
181
182 write_to_stream(output, &bytes).await?;
183 flush_stream(output).await?;
184
185 buffer.append(&mut bytes);
186 buffer = (&buffer[buffer.len()-5..]).to_vec();
187 if vec_has_sequence(&buffer, &[48, 13, 10, 13, 10]) { break;
189 }
190 buffer = (&buffer[buffer.len()-5..]).to_vec();
191 }
192 Ok(count)
193}
194
195pub async fn relay_sized_stream<I, O>(input: &mut I, output: &mut O, length: usize) -> Result<usize, Error>
201 where
202 I: Read + Unpin,
203 O: Write + Unpin,
204{
205 if length == 0 {
206 return Ok(0);
207 }
208
209 let mut count = 0;
210 loop {
211 let mut bytes = [0u8; 1024];
212 let size = match input.read(&mut bytes).await {
213 Ok(size) => size,
214 Err(_) => return Err(Error::StreamNotReadable),
215 };
216 let bytes = &mut bytes[0..size].to_vec();
217 count += size;
218
219 write_to_stream(output, &bytes).await?;
220 flush_stream(output).await?;
221
222 if size == 0 || count == length {
223 break;
224 } else if count > length {
225 return Err(Error::SizeLimitExceeded(length));
226 }
227 }
228 Ok(count)
229}
230
231pub async fn write_to_stream<S>(stream: &mut S, data: &[u8]) -> Result<usize, Error>
232 where
233 S: Write + Unpin,
234{
235 match stream.write(data).await {
236 Ok(size) => Ok(size),
237 Err(_) => Err(Error::StreamNotWritable),
238 }
239}
240
241pub async fn flush_stream<S>(stream: &mut S) -> Result<(), Error>
242 where
243 S: Write + Unpin,
244{
245 match stream.flush().await {
246 Ok(_) => Ok(()),
247 Err(_) => Err(Error::StreamNotWritable),
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[async_std::test]
    async fn reads_chunked_stream() {
        let stream = String::from("6\r\nHello \r\n6\r\nWorld!\r\n0\r\n\r\n");
        let mut stream = stream.as_bytes();
        let mut source = Vec::new();
        read_chunked_stream(&mut stream, &mut source, None).await.unwrap();
        assert_eq!(String::from_utf8(source).unwrap(), "Hello World!");
    }

    #[async_std::test]
    async fn reads_protocol_lines() {
        let mut input: &[u8] = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\nbody";
        let mut lines = Vec::new();
        read_protocol_lines(&mut input, &mut lines, None).await.unwrap();
        assert_eq!(lines, vec!["GET / HTTP/1.1", "Host: example.com"]);
        // The body after the blank line must be left unread.
        assert_eq!(input, b"body");
    }

    #[async_std::test]
    async fn relays_sized_stream() {
        let mut input: &[u8] = b"hello";
        let mut output: Vec<u8> = Vec::new();
        let count = relay_sized_stream(&mut input, &mut output, 5).await.unwrap();
        assert_eq!(count, 5);
        assert_eq!(output, b"hello");
    }

    #[test]
    fn checks_vector_has_sequence() {
        assert!(vec_has_sequence(&[0x0D, 0x0A, 0x0D, 0x0A], &[0x0D, 0x0A, 0x0D, 0x0A]));
        assert!(vec_has_sequence(&[1, 4, 6, 10, 21, 5, 150], &[10, 21, 5]));
        assert!(!vec_has_sequence(&[1, 4, 6, 10, 21, 5, 150], &[10, 5]));
    }
}