1use std::io::{Error, ErrorKind};
2use async_std::prelude::*;
3use async_std::io::{Read, Write};
4use crate::{read_chunk_line, write_slice, flush_write};
5
6pub async fn relay_exact<I, O>(input: &mut I, output: &mut O, length: usize) -> Result<usize, Error>
7 where
8 I: Read + Unpin,
9 O: Write + Unpin,
10{
11 if length == 0 {
12 return Ok(0);
13 }
14
15 let bufsize = 1024;
16 let mut total = 0;
17
18 loop {
19 let bufsize = match length - total < bufsize {
20 true => length - total, false => bufsize,
22 };
23
24 let mut bytes = vec![0u8; bufsize];
25 let size = input.read(&mut bytes).await?;
26 total += size;
27
28 write_slice(output, &bytes).await?;
29 flush_write(output).await?;
30
31 if total == length {
32 break;
33 }
34 }
35
36 Ok(total)
37}
38
39pub async fn relay_chunks<I, O>(input: &mut I, output: &mut O, limits: (Option<usize>, Option<usize>)) -> Result<usize, Error>
40 where
41 I: Read + Unpin,
42 O: Write + Unpin,
43{
44 let (chunklimit, datalimit) = limits;
45 let mut length = 0;
46 let mut total = 0; loop {
49 let (mut size, mut ext) = (vec![], vec![]);
50 read_chunk_line(input, (&mut size, &mut ext), chunklimit).await?;
51
52 length += write_slice(output, &size).await?;
53 if !ext.is_empty() {
54 length += write_slice(output, b";").await?;
55 length += write_slice(output, &ext).await?;
56 }
57 length += write_slice(output, b"\r\n").await?;
58
59 let size = match String::from_utf8(size) {
60 Ok(length) => match i64::from_str_radix(&length, 16) {
61 Ok(length) => length as usize,
62 Err(e) => return Err(Error::new(ErrorKind::InvalidData, e.to_string())),
63 },
64 Err(e) => return Err(Error::new(ErrorKind::InvalidData, e.to_string())),
65 };
66
67 if size == 0 {
68 length += relay_exact(input, output, 2).await?;
69 break; } else if datalimit.is_some() && total + size > datalimit.unwrap() {
71 return Err(Error::new(ErrorKind::InvalidData, format!("The operation hit the limit of {} bytes while relaying chunked HTTP body.", datalimit.unwrap())));
72 } else {
73 total += size;
74 length += relay_exact(input, output, size).await?;
75 length += relay_exact(input, output, 2).await?;
76 }
77 }
78
79 Ok(length)
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies a prefix of the input and leaves the rest unread.
    #[async_std::test]
    async fn relays_exact() {
        let mut output = Vec::new();
        let size = relay_exact(&mut "0123456789".as_bytes(), &mut output, 5).await.unwrap();
        assert_eq!(size, 5);
        assert_eq!(output, b"01234");
    }

    /// A zero length is a no-op that writes nothing.
    #[async_std::test]
    async fn relays_nothing_for_zero_length() {
        let mut output = Vec::new();
        let size = relay_exact(&mut "0123456789".as_bytes(), &mut output, 0).await.unwrap();
        assert_eq!(size, 0);
        assert!(output.is_empty());
    }

    /// Relays a full chunked body (with extensions), stops at the terminator,
    /// and enforces the data limit.
    #[async_std::test]
    async fn relays_chunks() {
        let mut output = Vec::new();
        let size = relay_chunks(&mut "6\r\nHello \r\n6;ex;ey\r\nWorld!\r\n0\r\n\r\nFoo: bar\r\n\r\n".as_bytes(), &mut output, (None, None)).await.unwrap();
        assert_eq!(size, 33);
        assert_eq!(output, "6\r\nHello \r\n6;ex;ey\r\nWorld!\r\n0\r\n\r\n".as_bytes());
        let mut output = Vec::new();
        let exceeds = relay_chunks(&mut "3\r\nHel\r\n0;ex;".as_bytes(), &mut output, (None, Some(2))).await;
        assert!(exceeds.is_err());
    }

    /// A body consisting only of the last-chunk line is relayed verbatim.
    #[async_std::test]
    async fn relays_terminating_chunk_only() {
        let mut output = Vec::new();
        let size = relay_chunks(&mut "0\r\n\r\n".as_bytes(), &mut output, (None, None)).await.unwrap();
        assert_eq!(size, 5);
        assert_eq!(output, "0\r\n\r\n".as_bytes());
    }

    /// A chunk size that is not hexadecimal is rejected.
    #[async_std::test]
    async fn rejects_non_hex_chunk_size() {
        let mut output = Vec::new();
        let result = relay_chunks(&mut "zz\r\nabc\r\n0\r\n\r\n".as_bytes(), &mut output, (None, None)).await;
        assert!(result.is_err());
    }
}