async_httype/
request.rs

1use std::fmt;
2use std::collections::HashMap;
3use std::collections::hash_map::RandomState;
4use async_std::io::{Read, Write};
5use crate::{Error, read_protocol_lines, write_to_stream, flush_stream};
6
/// An HTTP request head: request line components, header fields, and the
/// raw protocol lines they are parsed from / serialized to.
#[derive(Debug)]
pub struct Request {
    /// Request method, e.g. `GET`; `None` until set or parsed.
    method: Option<String>,
    /// Request target as it appears on the request line.
    uri: Option<String>,
    /// Protocol version without the `HTTP/` prefix (`"1.0"` / `"1.1"` when parsed).
    version: Option<String>,
    /// Header fields keyed by name exactly as received or set (case-sensitive).
    headers: HashMap<String, String>,
    /// Running total of the lengths reported by every `read` call so far.
    length: usize,
    /// Optional cap on `length`; `None` means unlimited.
    length_limit: Option<usize>,
    /// Raw protocol lines; index 0 is the request line when present.
    lines: Vec<String>,
}
17
18impl Request {
19
20    pub fn new() -> Self {
21        Self {
22            method: None,
23            uri: None,
24            version: None,
25            headers: HashMap::with_hasher(RandomState::new()),
26            length: 0,
27            length_limit: None,
28            lines: Vec::new(),
29        }
30    }
31
32    pub fn method(&self) -> &Option<String> {
33        &self.method
34    }
35
36    pub fn uri(&self) -> &Option<String> {
37        &self.uri
38    }
39
40    pub fn version(&self) -> &Option<String> {
41        &self.version
42    }
43
44    pub fn headers(&self) -> &HashMap<String, String> {
45        &self.headers
46    }
47
48    pub fn header<N: Into<String>>(&self, name: N) -> Option<&String> {
49        self.headers.get(&name.into())
50    }
51
52    pub fn length(&self) -> usize {
53        self.length
54    }
55
56    pub fn length_limit(&self) -> Option<usize> {
57        self.length_limit
58    }
59
60    pub fn has_method(&self) -> bool {
61        self.method.is_some()
62    }
63
64    pub fn has_uri(&self) -> bool {
65        self.uri.is_some()
66    }
67
68    pub fn has_version(&self) -> bool {
69        self.version.is_some()
70    }
71
72    pub fn has_headers(&self) -> bool {
73        !self.headers.is_empty()
74    }
75
76    pub fn has_header<N: Into<String>>(&self, name: N) -> bool {
77        self.headers.contains_key(&name.into())
78    }
79
80    pub fn has_length_limit(&self) -> bool {
81        self.length_limit.is_some()
82    }
83
84    pub fn set_method<V: Into<String>>(&mut self, value: V) {
85        self.method = Some(value.into());
86    }
87
88    pub fn set_uri<V: Into<String>>(&mut self, value: V) {
89        self.uri = Some(value.into());
90    }
91
92    pub fn set_version<V: Into<String>>(&mut self, value: V) {
93        self.version = Some(value.into());
94    }
95
96    pub fn set_header<N: Into<String>, V: Into<String>>(&mut self, name: N, value: V) {
97        self.headers.insert(name.into(), value.into());
98    }
99
100    pub fn set_length_limit(&mut self, limit: usize) {
101        self.length_limit = Some(limit);
102    }
103
104    pub fn remove_method(&mut self) {
105        self.method = None;
106    }
107
108    pub fn remove_uri(&mut self) {
109        self.uri = None;
110    }
111
112    pub fn remove_version<V: Into<String>>(&mut self) {
113        self.version = None;
114    }
115
116    pub fn remove_header<N: Into<String>>(&mut self, name: N) {
117        self.headers.remove(&name.into());
118    }
119
120    pub fn remove_length_limit(&mut self) {
121        self.length_limit = None;
122    }
123
124    pub async fn read<I>(&mut self, stream: &mut I) -> Result<usize, Error>
125        where
126        I: Read + Unpin,
127    {
128        let limit = match self.length_limit {
129            Some(limit) => match limit == 0 {
130                true => return Err(Error::SizeLimitExceeded(limit)),
131                false => Some(limit - self.length),
132            },
133            None => None,
134        };
135
136        let length = read_protocol_lines(stream, &mut self.lines, limit).await?;
137        self.length += length;
138
139        Ok(length)
140    }
141
142    pub async fn write<I>(&mut self, stream: &mut I) -> Result<usize, Error>
143        where
144        I: Write + Unpin,
145    {
146        let size = write_to_stream(stream, &self.to_bytes()).await?;
147        flush_stream(stream).await?;
148        Ok(size)
149    }
150
151    pub fn clear(&mut self) {
152        self.method = None;
153        self.uri = None;
154        self.version = None;
155        self.headers.clear();
156        self.length = 0;
157        self.length_limit = None;
158        self.lines.clear();
159    }
160
161    pub fn parse_head(&mut self) -> Result<(), Error> {
162        let mut parts = match self.lines.first() {
163            Some(head) => head.splitn(3, " "),
164            None => return Err(Error::InvalidData),
165        };
166
167        self.method = match parts.next() {
168            Some(method) => Some(String::from(method)),
169            None => return Err(Error::InvalidData),
170        };
171        self.uri = match parts.next() {
172            Some(uri) => Some(String::from(uri)),
173            None => return Err(Error::InvalidData),
174        };
175        self.version = match parts.next() {
176            Some(version) => match version {
177                "HTTP/1.0" => Some(String::from("1.0")),
178                "HTTP/1.1" => Some(String::from("1.1")),
179                _ => return Err(Error::InvalidData),
180            },
181            None => return Err(Error::InvalidData),
182        };
183
184        Ok(())
185    }
186
187    pub fn parse_headers(&mut self) -> Result<(), Error> {
188        for line in self.lines.iter().skip(1) {
189            let mut parts = line.splitn(2, ": ");
190            let name = match parts.next() {
191                Some(name) => String::from(name),
192                None => return Err(Error::InvalidData),
193            };
194            let value = match parts.next() {
195                Some(value) => String::from(value),
196                None => return Err(Error::InvalidData),
197            };
198            self.headers.insert(name, value);
199        }
200
201        Ok(())
202    }
203
204    pub fn build_head(&mut self) -> Result<(), Error> {
205        let method = match &self.method {
206            Some(method) => method,
207            None => return Err(Error::InvalidData),
208        };
209        let uri = match &self.uri {
210            Some(uri) => uri,
211            None => return Err(Error::InvalidData),
212        };
213        let version = match &self.version {
214            Some(version) => format!("HTTP/{}", version),
215            None => return Err(Error::InvalidData),
216        };
217
218        let head = format!("{} {} {}", method, uri, version);
219        if self.lines.is_empty() {
220            self.lines.push(head);
221        } else {
222            self.lines[0] = head;
223        }
224
225        Ok(())
226    }
227
228    pub fn build_headers(&mut self) -> Result<(), Error> {
229        let head = match self.lines.first() {
230            Some(head) => Some(head.clone()),
231            None => None,
232        };
233
234        self.lines.clear();
235        if head.is_some() {
236            self.lines.push(head.unwrap());
237        }
238
239        for (name, value) in &self.headers {
240            self.lines.push(format!("{}: {}", name, value));
241        }
242
243        Ok(())
244    }
245
246    pub fn to_bytes(&self) -> Vec<u8> {
247        self.to_string().as_bytes().to_vec()
248    }
249
250    pub fn to_string(&self) -> String {
251        self.lines.join("\r\n") + "\r\n\r\n"
252    }
253}
254
255impl fmt::Display for Request {
256    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
257        write!(fmt, "{}", self.to_string())
258    }
259}
260
261impl From<Request> for String {
262    fn from(item: Request) -> String {
263        item.to_string()
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request whose line buffer holds exactly `lines`.
    fn with_lines(lines: &[&str]) -> Request {
        let mut req = Request::new();
        req.lines = lines.iter().map(|l| String::from(*l)).collect();
        req
    }

    #[async_std::test]
    async fn writes_to_stream() {
        let mut stream = Vec::new();
        let mut req = Request::new();
        req.set_method("POST");
        req.set_uri("/foo");
        req.set_version("1.1");
        req.build_head().unwrap();
        req.write(&mut stream).await.unwrap();
        assert_eq!(String::from_utf8(stream).unwrap(), "POST /foo HTTP/1.1\r\n\r\n");
    }

    #[test]
    fn parses_head() {
        let mut req = with_lines(&["GET /foo HTTP/1.1"]);
        req.parse_head().unwrap();
        assert_eq!(req.method(), &Some(String::from("GET")));
        assert_eq!(req.uri(), &Some(String::from("/foo")));
        assert_eq!(req.version(), &Some(String::from("1.1")));
    }

    #[test]
    fn rejects_bad_head() {
        assert!(Request::new().parse_head().is_err());
        assert!(with_lines(&["GET /foo HTTP/2.0"]).parse_head().is_err());
        assert!(with_lines(&["GET"]).parse_head().is_err());
    }

    #[test]
    fn parses_headers() {
        let mut req = with_lines(&["GET / HTTP/1.1", "Host: example.com", "Accept: */*"]);
        req.parse_headers().unwrap();
        assert_eq!(req.header("Host"), Some(&String::from("example.com")));
        assert_eq!(req.header("Accept"), Some(&String::from("*/*")));
        assert!(with_lines(&["GET / HTTP/1.1", "NoColonHere"]).parse_headers().is_err());
    }

    #[test]
    fn builds_head_and_headers() {
        let mut req = Request::new();
        req.set_method("GET");
        req.set_uri("/");
        req.set_version("1.1");
        req.set_header("Host", "x");
        req.build_head().unwrap();
        req.build_headers().unwrap();
        assert_eq!(req.to_string(), "GET / HTTP/1.1\r\nHost: x\r\n\r\n");
    }

    #[async_std::test]
    async fn read_rejects_zero_limit() {
        let mut req = Request::new();
        req.set_length_limit(0);
        let mut stream: &[u8] = b"GET / HTTP/1.1\r\n\r\n";
        match req.read(&mut stream).await {
            Err(Error::SizeLimitExceeded(limit)) => assert_eq!(limit, 0),
            _ => panic!("expected SizeLimitExceeded"),
        }
    }
}