1use std::collections::HashMap;
2use async_std::io::{Read, Write};
3use crate::{Error, relay_chunked_stream, relay_sized_stream};
4
/// Relays message bodies from one stream to another while counting the bytes
/// transferred and, optionally, enforcing an upper bound on that running total.
///
/// `Default` yields the same value as `Relay::new()`: nothing relayed yet and no limit.
#[derive(Debug, Default)]
pub struct Relay {
    /// Running total of body bytes relayed by this instance.
    length: usize,
    /// Optional cap on the running total; `None` means unlimited.
    length_limit: Option<usize>,
}
10
11impl Relay {
12
13 pub fn new() -> Self {
14 Self {
15 length: 0,
16 length_limit: None,
17 }
18 }
19
20 pub fn length(&self) -> usize {
21 self.length
22 }
23
24 pub fn length_limit(&self) -> Option<usize> {
25 self.length_limit
26 }
27
28 pub fn has_length_limit(&self) -> bool {
29 self.length_limit.is_some()
30 }
31
32 pub fn set_length_limit(&mut self, limit: usize) {
33 self.length_limit = Some(limit);
34 }
35
36 pub fn remove_length_limit(&mut self) {
37 self.length_limit = None;
38 }
39
40 pub async fn relay<I, O>(&mut self, input: &mut I, output: &mut O, req: &HashMap<String, String>) -> Result<usize, Error>
41 where
42 I: Write + Read + Unpin,
43 O: Write + Read + Unpin,
44 {
45 let length = req.get("Content-Length");
46 let encoding = req.get("Transfer-Encoding");
47
48 if encoding.is_some() && encoding.unwrap().contains(&String::from("chunked")) {
49 self.relay_chunked(input, output).await
50 } else {
51 let length = match length {
52 Some(length) => match length.parse::<usize>() {
53 Ok(length) => length,
54 Err(_) => return Err(Error::InvalidHeader(String::from("Content-Length"))),
55 },
56 None => return Err(Error::InvalidHeader(String::from("Content-Length"))),
57 };
58 self.relay_sized(input, output, length).await
59 }
60 }
61
62 pub async fn relay_chunked<I, O>(&mut self, input: &mut I, output: &mut O) -> Result<usize, Error>
63 where
64 I: Write + Read + Unpin,
65 O: Write + Read + Unpin,
66 {
67 let limit = match self.length_limit {
68 Some(limit) => match limit == 0 {
69 true => return Err(Error::SizeLimitExceeded(limit)),
70 false => Some(limit - self.length),
71 },
72 None => None,
73 };
74
75 let length = relay_chunked_stream(input, output, limit).await?;
76 self.length += length;
77
78 Ok(length)
79 }
80
81 pub async fn relay_sized<I, O>(&mut self, input: &mut I, output: &mut O, length: usize) -> Result<usize, Error>
82 where
83 I: Read + Unpin,
84 O: Write + Unpin,
85 {
86 match self.length_limit {
87 Some(limit) => match length + self.length > limit {
88 true => return Err(Error::SizeLimitExceeded(limit)),
89 false => (),
90 },
91 None => (),
92 };
93
94 let length = relay_sized_stream(input, output, length).await?;
95 self.length += length;
96
97 Ok(length)
98 }
99
100 pub fn clear(&mut self) {
101 self.length = 0;
102 self.length_limit = None;
103 }
104}