1use std::fmt;
2use std::collections::HashMap;
3use std::collections::hash_map::RandomState;
4use async_std::io::{Read, Write};
5use crate::{Error, read_protocol_lines, write_to_stream, flush_stream};
6
/// An HTTP response head: status line fields, headers and the raw protocol
/// lines they are parsed from / serialised into.
#[derive(Debug)]
pub struct Response {
    /// Numeric status code (e.g. `200`).
    status_code: Option<usize>,
    /// Reason phrase following the status code (e.g. `"OK"`).
    status_message: Option<String>,
    /// Protocol version without the `HTTP/` prefix (e.g. `"1.1"`).
    version: Option<String>,
    /// Header name → value; keys are compared case-sensitively.
    headers: HashMap<String, String>,
    /// Running total of the values returned by `read_protocol_lines`
    /// (presumably bytes — TODO confirm against that helper).
    length: usize,
    /// Optional cap on `length` enforced while reading.
    length_limit: Option<usize>,
    /// Raw protocol lines; the first is treated as the status line.
    lines: Vec<String>,
}
17
18impl Response {
19
20 pub fn new() -> Self {
21 Self {
22 status_code: None,
23 status_message: None,
24 version: None,
25 headers: HashMap::with_hasher(RandomState::new()),
26 length: 0,
27 length_limit: None,
28 lines: Vec::new(),
29 }
30 }
31
32 pub fn status_code(&self) -> &Option<usize> {
33 &self.status_code
34 }
35
36 pub fn status_message(&self) -> &Option<String> {
37 &self.status_message
38 }
39
40 pub fn version(&self) -> &Option<String> {
41 &self.version
42 }
43
44 pub fn headers(&self) -> &HashMap<String, String> {
45 &self.headers
46 }
47
48 pub fn header<N: Into<String>>(&self, name: N) -> Option<&String> {
49 self.headers.get(&name.into())
50 }
51
52 pub fn length(&self) -> usize {
53 self.length
54 }
55
56 pub fn length_limit(&self) -> Option<usize> {
57 self.length_limit
58 }
59
60 pub fn lines(&self) -> &Vec<String> {
61 &self.lines
62 }
63
64 pub fn has_status_code(&self) -> bool {
65 self.status_code.is_some()
66 }
67
68 pub fn has_status_message(&self) -> bool {
69 self.status_message.is_some()
70 }
71
72 pub fn has_version(&self) -> bool {
73 self.version.is_some()
74 }
75
76 pub fn has_headers(&self) -> bool {
77 !self.headers.is_empty()
78 }
79
80 pub fn has_header<N: Into<String>>(&self, name: N) -> bool {
81 self.headers.contains_key(&name.into())
82 }
83
84 pub fn has_length_limit(&self) -> bool {
85 self.length_limit.is_some()
86 }
87
88 pub fn set_status_code(&mut self, value: usize) {
89 self.status_code = Some(value);
90 }
91
92 pub fn set_status_message<V: Into<String>>(&mut self, value: V) {
93 self.status_message = Some(value.into());
94 }
95
96 pub fn set_version<V: Into<String>>(&mut self, value: V) {
97 self.version = Some(value.into());
98 }
99
100 pub fn set_header<N: Into<String>, V: Into<String>>(&mut self, name: N, value: V) {
101 self.headers.insert(name.into(), value.into());
102 }
103
104 pub fn set_length_limit(&mut self, limit: usize) {
105 self.length_limit = Some(limit);
106 }
107
108 pub fn remove_status_code(&mut self) {
109 self.status_code = None;
110 }
111
112 pub fn remove_status_message(&mut self) {
113 self.status_message = None;
114 }
115
116 pub fn remove_version<V: Into<String>>(&mut self) {
117 self.version = None;
118 }
119
120 pub fn remove_header<N: Into<String>>(&mut self, name: N) {
121 self.headers.remove(&name.into());
122 }
123
124 pub fn remove_length_limit(&mut self) {
125 self.length_limit = None;
126 }
127
128 pub async fn read<I>(&mut self, stream: &mut I) -> Result<usize, Error>
129 where
130 I: Read + Unpin,
131 {
132 let limit = match self.length_limit {
133 Some(limit) => match limit == 0 {
134 true => return Err(Error::SizeLimitExceeded(limit)),
135 false => Some(limit - self.length),
136 },
137 None => None,
138 };
139
140 let length = read_protocol_lines(stream, &mut self.lines, limit).await?;
141 self.length += length;
142
143 Ok(length)
144 }
145
146 pub async fn write<I>(&mut self, stream: &mut I) -> Result<usize, Error>
147 where
148 I: Write + Unpin,
149 {
150 let size = write_to_stream(stream, &self.to_bytes()).await?;
151 flush_stream(stream).await?;
152 Ok(size)
153 }
154
155 pub fn clear(&mut self) {
156 self.status_code = None;
157 self.version = None;
158 self.headers.clear();
159 self.length = 0;
160 self.length_limit = None;
161 self.lines.clear();
162 }
163
164 pub fn parse_head(&mut self) -> Result<(), Error> {
165 let mut parts = match self.lines.first() {
166 Some(head) => head.splitn(3, " "),
167 None => return Err(Error::InvalidData),
168 };
169
170 self.version = match parts.next() {
171 Some(version) => match version {
172 "HTTP/1.0" => Some(String::from("1.0")),
173 "HTTP/1.1" => Some(String::from("1.1")),
174 _ => return Err(Error::InvalidData),
175 },
176 None => return Err(Error::InvalidData),
177 };
178 self.status_code = match parts.next() {
179 Some(status_code) => match status_code.parse::<usize>() {
180 Ok(status_code) => Some(status_code),
181 Err(_) => return Err(Error::InvalidData),
182 },
183 None => return Err(Error::InvalidData),
184 };
185
186 Ok(())
187 }
188
189 pub fn parse_headers(&mut self) -> Result<(), Error> {
190 for line in self.lines.iter().skip(1) {
191 if line == "" {
192 break;
193 }
194 let mut parts = line.splitn(2, ": ");
195 let name = match parts.next() {
196 Some(name) => String::from(name),
197 None => return Err(Error::InvalidData),
198 };
199 let value = match parts.next() {
200 Some(value) => String::from(value),
201 None => return Err(Error::InvalidData),
202 };
203 self.headers.insert(name, value);
204 }
205
206 Ok(())
207 }
208
209 pub fn build_head(&mut self) -> Result<(), Error> {
210 let version = match &self.version {
211 Some(version) => format!("HTTP/{}", version),
212 None => return Err(Error::InvalidData),
213 };
214 let status_code = match &self.status_code {
215 Some(code) => code,
216 None => return Err(Error::InvalidData),
217 };
218 let status_message = match &self.status_message {
219 Some(message) => message,
220 None => return Err(Error::InvalidData),
221 };
222
223 let head = format!("{} {} {}", version, status_code, status_message);
224 if self.lines.is_empty() {
225 self.lines.push(head);
226 } else {
227 self.lines[0] = head;
228 }
229
230 Ok(())
231 }
232
233 pub fn build_headers(&mut self) -> Result<(), Error> {
234 let head = match self.lines.first() {
235 Some(head) => Some(head.clone()),
236 None => None,
237 };
238 self.lines.clear();
239 if head.is_some() {
240 self.lines.push(head.unwrap());
241 }
242
243 for (name, value) in &self.headers {
244 self.lines.push(format!("{}: {}", name, value));
245 }
246
247 Ok(())
248 }
249
250 pub fn to_bytes(&self) -> Vec<u8> {
251 self.to_string().as_bytes().to_vec()
252 }
253
254 pub fn to_string(&self) -> String {
255 self.lines.join("\r\n") + "\r\n\r\n"
256 }
257}
258
259impl fmt::Display for Response {
260 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
261 write!(fmt, "{}", self.to_string())
262 }
263}
264
265impl From<Response> for String {
266 fn from(item: Response) -> String {
267 item.to_string()
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// A response holding only a status line serialises to that line,
    /// followed by the empty line that terminates the head.
    #[async_std::test]
    async fn writes_to_stream() {
        let mut buffer = Vec::new();
        let mut res = Response::new();
        res.set_version("1.1");
        res.set_status_code(200);
        res.set_status_message("OK");
        res.build_head().unwrap();
        res.write(&mut buffer).await.unwrap();

        let written = String::from_utf8(buffer).unwrap();
        assert_eq!(written, "HTTP/1.1 200 OK\r\n\r\n");
    }
}