1use std::fmt;
2use std::collections::HashMap;
3use std::collections::hash_map::RandomState;
4use async_std::io::{Read, Write};
5use crate::{Error, read_protocol_lines, write_to_stream, flush_stream};
6
/// An HTTP request head: the request line, the header fields and the raw
/// protocol lines they are parsed from or serialized to.
///
/// `Default` yields the same empty state as `Request::new()`.
#[derive(Debug, Default)]
pub struct Request {
    /// Request method (e.g. `POST`); `None` until set or parsed.
    method: Option<String>,
    /// Request target as it appears in the request line; `None` until set or parsed.
    uri: Option<String>,
    /// Protocol version without the `HTTP/` prefix (e.g. `1.1`); `None` until set or parsed.
    version: Option<String>,
    /// Header fields keyed by name exactly as given (no case normalization).
    headers: HashMap<String, String>,
    /// Running total of the amounts reported by successful `read` calls.
    length: usize,
    /// Optional cap applied to `length` while reading; `None` means unlimited.
    length_limit: Option<usize>,
    /// Raw protocol lines: index 0 is the request line, the rest are header lines.
    lines: Vec<String>,
}
17
18impl Request {
19
20 pub fn new() -> Self {
21 Self {
22 method: None,
23 uri: None,
24 version: None,
25 headers: HashMap::with_hasher(RandomState::new()),
26 length: 0,
27 length_limit: None,
28 lines: Vec::new(),
29 }
30 }
31
32 pub fn method(&self) -> &Option<String> {
33 &self.method
34 }
35
36 pub fn uri(&self) -> &Option<String> {
37 &self.uri
38 }
39
40 pub fn version(&self) -> &Option<String> {
41 &self.version
42 }
43
44 pub fn headers(&self) -> &HashMap<String, String> {
45 &self.headers
46 }
47
48 pub fn header<N: Into<String>>(&self, name: N) -> Option<&String> {
49 self.headers.get(&name.into())
50 }
51
52 pub fn length(&self) -> usize {
53 self.length
54 }
55
56 pub fn length_limit(&self) -> Option<usize> {
57 self.length_limit
58 }
59
60 pub fn has_method(&self) -> bool {
61 self.method.is_some()
62 }
63
64 pub fn has_uri(&self) -> bool {
65 self.uri.is_some()
66 }
67
68 pub fn has_version(&self) -> bool {
69 self.version.is_some()
70 }
71
72 pub fn has_headers(&self) -> bool {
73 !self.headers.is_empty()
74 }
75
76 pub fn has_header<N: Into<String>>(&self, name: N) -> bool {
77 self.headers.contains_key(&name.into())
78 }
79
80 pub fn has_length_limit(&self) -> bool {
81 self.length_limit.is_some()
82 }
83
84 pub fn set_method<V: Into<String>>(&mut self, value: V) {
85 self.method = Some(value.into());
86 }
87
88 pub fn set_uri<V: Into<String>>(&mut self, value: V) {
89 self.uri = Some(value.into());
90 }
91
92 pub fn set_version<V: Into<String>>(&mut self, value: V) {
93 self.version = Some(value.into());
94 }
95
96 pub fn set_header<N: Into<String>, V: Into<String>>(&mut self, name: N, value: V) {
97 self.headers.insert(name.into(), value.into());
98 }
99
100 pub fn set_length_limit(&mut self, limit: usize) {
101 self.length_limit = Some(limit);
102 }
103
104 pub fn remove_method(&mut self) {
105 self.method = None;
106 }
107
108 pub fn remove_uri(&mut self) {
109 self.uri = None;
110 }
111
112 pub fn remove_version<V: Into<String>>(&mut self) {
113 self.version = None;
114 }
115
116 pub fn remove_header<N: Into<String>>(&mut self, name: N) {
117 self.headers.remove(&name.into());
118 }
119
120 pub fn remove_length_limit(&mut self) {
121 self.length_limit = None;
122 }
123
124 pub async fn read<I>(&mut self, stream: &mut I) -> Result<usize, Error>
125 where
126 I: Read + Unpin,
127 {
128 let limit = match self.length_limit {
129 Some(limit) => match limit == 0 {
130 true => return Err(Error::SizeLimitExceeded(limit)),
131 false => Some(limit - self.length),
132 },
133 None => None,
134 };
135
136 let length = read_protocol_lines(stream, &mut self.lines, limit).await?;
137 self.length += length;
138
139 Ok(length)
140 }
141
142 pub async fn write<I>(&mut self, stream: &mut I) -> Result<usize, Error>
143 where
144 I: Write + Unpin,
145 {
146 let size = write_to_stream(stream, &self.to_bytes()).await?;
147 flush_stream(stream).await?;
148 Ok(size)
149 }
150
151 pub fn clear(&mut self) {
152 self.method = None;
153 self.uri = None;
154 self.version = None;
155 self.headers.clear();
156 self.length = 0;
157 self.length_limit = None;
158 self.lines.clear();
159 }
160
161 pub fn parse_head(&mut self) -> Result<(), Error> {
162 let mut parts = match self.lines.first() {
163 Some(head) => head.splitn(3, " "),
164 None => return Err(Error::InvalidData),
165 };
166
167 self.method = match parts.next() {
168 Some(method) => Some(String::from(method)),
169 None => return Err(Error::InvalidData),
170 };
171 self.uri = match parts.next() {
172 Some(uri) => Some(String::from(uri)),
173 None => return Err(Error::InvalidData),
174 };
175 self.version = match parts.next() {
176 Some(version) => match version {
177 "HTTP/1.0" => Some(String::from("1.0")),
178 "HTTP/1.1" => Some(String::from("1.1")),
179 _ => return Err(Error::InvalidData),
180 },
181 None => return Err(Error::InvalidData),
182 };
183
184 Ok(())
185 }
186
187 pub fn parse_headers(&mut self) -> Result<(), Error> {
188 for line in self.lines.iter().skip(1) {
189 let mut parts = line.splitn(2, ": ");
190 let name = match parts.next() {
191 Some(name) => String::from(name),
192 None => return Err(Error::InvalidData),
193 };
194 let value = match parts.next() {
195 Some(value) => String::from(value),
196 None => return Err(Error::InvalidData),
197 };
198 self.headers.insert(name, value);
199 }
200
201 Ok(())
202 }
203
204 pub fn build_head(&mut self) -> Result<(), Error> {
205 let method = match &self.method {
206 Some(method) => method,
207 None => return Err(Error::InvalidData),
208 };
209 let uri = match &self.uri {
210 Some(uri) => uri,
211 None => return Err(Error::InvalidData),
212 };
213 let version = match &self.version {
214 Some(version) => format!("HTTP/{}", version),
215 None => return Err(Error::InvalidData),
216 };
217
218 let head = format!("{} {} {}", method, uri, version);
219 if self.lines.is_empty() {
220 self.lines.push(head);
221 } else {
222 self.lines[0] = head;
223 }
224
225 Ok(())
226 }
227
228 pub fn build_headers(&mut self) -> Result<(), Error> {
229 let head = match self.lines.first() {
230 Some(head) => Some(head.clone()),
231 None => None,
232 };
233
234 self.lines.clear();
235 if head.is_some() {
236 self.lines.push(head.unwrap());
237 }
238
239 for (name, value) in &self.headers {
240 self.lines.push(format!("{}: {}", name, value));
241 }
242
243 Ok(())
244 }
245
246 pub fn to_bytes(&self) -> Vec<u8> {
247 self.to_string().as_bytes().to_vec()
248 }
249
250 pub fn to_string(&self) -> String {
251 self.lines.join("\r\n") + "\r\n\r\n"
252 }
253}
254
255impl fmt::Display for Request {
256 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
257 write!(fmt, "{}", self.to_string())
258 }
259}
260
261impl From<Request> for String {
262 fn from(item: Request) -> String {
263 item.to_string()
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// A request holding only a request line serializes as that line
    /// followed by an empty line.
    #[async_std::test]
    async fn writes_to_stream() {
        let mut request = Request::new();
        request.set_method("POST");
        request.set_uri("/foo");
        request.set_version("1.1");
        request.build_head().unwrap();

        let mut sink = Vec::new();
        request.write(&mut sink).await.unwrap();

        let written = String::from_utf8(sink).unwrap();
        assert_eq!(written, "POST /foo HTTP/1.1\r\n\r\n");
    }
}