async_httype/
response.rs

1use std::fmt;
2use std::collections::HashMap;
3use std::collections::hash_map::RandomState;
4use async_std::io::{Read, Write};
5use crate::{Error, read_protocol_lines, write_to_stream, flush_stream};
6
/// An HTTP response head: the status line parts, the header map and the raw
/// protocol lines they are parsed from or built into.
#[derive(Debug)]
pub struct Response {
    /// Numeric status code, once set or parsed.
    status_code: Option<usize>,
    /// Status message (reason phrase), once set.
    status_message: Option<String>,
    /// Protocol version without the `HTTP/` prefix, once set or parsed.
    version: Option<String>,
    /// Header values keyed by header name.
    headers: HashMap<String, String>,
    /// Running total of the lengths reported by reads into this response.
    length: usize,
    /// Optional cap applied to `length` when reading.
    length_limit: Option<usize>,
    /// Raw protocol lines; the first entry is treated as the status line.
    lines: Vec<String>,
}
17
18impl Response {
19
20    pub fn new() -> Self {
21        Self {
22            status_code: None,
23            status_message: None,
24            version: None,
25            headers: HashMap::with_hasher(RandomState::new()),
26            length: 0,
27            length_limit: None,
28            lines: Vec::new(),
29        }
30    }
31
32    pub fn status_code(&self) -> &Option<usize> {
33        &self.status_code
34    }
35
36    pub fn status_message(&self) -> &Option<String> {
37        &self.status_message
38    }
39
40    pub fn version(&self) -> &Option<String> {
41        &self.version
42    }
43
44    pub fn headers(&self) -> &HashMap<String, String> {
45        &self.headers
46    }
47
48    pub fn header<N: Into<String>>(&self, name: N) -> Option<&String> {
49        self.headers.get(&name.into())
50    }
51
52    pub fn length(&self) -> usize {
53        self.length
54    }
55
56    pub fn length_limit(&self) -> Option<usize> {
57        self.length_limit
58    }
59
60    pub fn lines(&self) -> &Vec<String> {
61        &self.lines
62    }
63
64    pub fn has_status_code(&self) -> bool {
65        self.status_code.is_some()
66    }
67
68    pub fn has_status_message(&self) -> bool {
69        self.status_message.is_some()
70    }
71
72    pub fn has_version(&self) -> bool {
73        self.version.is_some()
74    }
75
76    pub fn has_headers(&self) -> bool {
77        !self.headers.is_empty()
78    }
79
80    pub fn has_header<N: Into<String>>(&self, name: N) -> bool {
81        self.headers.contains_key(&name.into())
82    }
83
84    pub fn has_length_limit(&self) -> bool {
85        self.length_limit.is_some()
86    }
87
88    pub fn set_status_code(&mut self, value: usize) {
89        self.status_code = Some(value);
90    }
91
92    pub fn set_status_message<V: Into<String>>(&mut self, value: V) {
93        self.status_message = Some(value.into());
94    }
95
96    pub fn set_version<V: Into<String>>(&mut self, value: V) {
97        self.version = Some(value.into());
98    }
99
100    pub fn set_header<N: Into<String>, V: Into<String>>(&mut self, name: N, value: V) {
101        self.headers.insert(name.into(), value.into());
102    }
103
104    pub fn set_length_limit(&mut self, limit: usize) {
105        self.length_limit = Some(limit);
106    }
107
108    pub fn remove_status_code(&mut self) {
109        self.status_code = None;
110    }
111
112    pub fn remove_status_message(&mut self) {
113        self.status_message = None;
114    }
115
116    pub fn remove_version<V: Into<String>>(&mut self) {
117        self.version = None;
118    }
119
120    pub fn remove_header<N: Into<String>>(&mut self, name: N) {
121        self.headers.remove(&name.into());
122    }
123
124    pub fn remove_length_limit(&mut self) {
125        self.length_limit = None;
126    }
127
128    pub async fn read<I>(&mut self, stream: &mut I) -> Result<usize, Error>
129        where
130        I: Read + Unpin,
131    {
132        let limit = match self.length_limit {
133            Some(limit) => match limit == 0 {
134                true => return Err(Error::SizeLimitExceeded(limit)),
135                false => Some(limit - self.length),
136            },
137            None => None,
138        };
139
140        let length = read_protocol_lines(stream, &mut self.lines, limit).await?;
141        self.length += length;
142
143        Ok(length)
144    }
145
146    pub async fn write<I>(&mut self, stream: &mut I) -> Result<usize, Error>
147        where
148        I: Write + Unpin,
149    {
150        let size = write_to_stream(stream, &self.to_bytes()).await?;
151        flush_stream(stream).await?;
152        Ok(size)
153    }
154
155    pub fn clear(&mut self) {
156        self.status_code = None;
157        self.version = None;
158        self.headers.clear();
159        self.length = 0;
160        self.length_limit = None;
161        self.lines.clear();
162    }
163
164    pub fn parse_head(&mut self) -> Result<(), Error> {
165        let mut parts = match self.lines.first() {
166            Some(head) => head.splitn(3, " "),
167            None => return Err(Error::InvalidData),
168        };
169
170        self.version = match parts.next() {
171            Some(version) => match version {
172                "HTTP/1.0" => Some(String::from("1.0")),
173                "HTTP/1.1" => Some(String::from("1.1")),
174                _ => return Err(Error::InvalidData),
175            },
176            None => return Err(Error::InvalidData),
177        };
178        self.status_code = match parts.next() {
179            Some(status_code) => match status_code.parse::<usize>() {
180                Ok(status_code) => Some(status_code),
181                Err(_) => return Err(Error::InvalidData),
182            },
183            None => return Err(Error::InvalidData),
184        };
185
186        Ok(())
187    }
188
189    pub fn parse_headers(&mut self) -> Result<(), Error> {
190        for line in self.lines.iter().skip(1) {
191            if line == "" {
192                break;
193            }
194            let mut parts = line.splitn(2, ": ");
195            let name = match parts.next() {
196                Some(name) => String::from(name),
197                None => return Err(Error::InvalidData),
198            };
199            let value = match parts.next() {
200                Some(value) => String::from(value),
201                None => return Err(Error::InvalidData),
202            };
203            self.headers.insert(name, value);
204        }
205
206        Ok(())
207    }
208
209    pub fn build_head(&mut self) -> Result<(), Error> {
210        let version = match &self.version {
211            Some(version) => format!("HTTP/{}", version),
212            None => return Err(Error::InvalidData),
213        };
214        let status_code = match &self.status_code {
215            Some(code) => code,
216            None => return Err(Error::InvalidData),
217        };
218        let status_message = match &self.status_message {
219            Some(message) => message,
220            None => return Err(Error::InvalidData),
221        };
222
223        let head = format!("{} {} {}", version, status_code, status_message);
224        if self.lines.is_empty() {
225            self.lines.push(head);
226        } else {
227            self.lines[0] = head;
228        }
229
230        Ok(())
231    }
232
233    pub fn build_headers(&mut self) -> Result<(), Error> {
234        let head = match self.lines.first() {
235            Some(head) => Some(head.clone()),
236            None => None,
237        };
238        self.lines.clear();
239        if head.is_some() {
240            self.lines.push(head.unwrap());
241        }
242
243        for (name, value) in &self.headers {
244            self.lines.push(format!("{}: {}", name, value));
245        }
246
247        Ok(())
248    }
249
250    pub fn to_bytes(&self) -> Vec<u8> {
251        self.to_string().as_bytes().to_vec()
252    }
253
254    pub fn to_string(&self) -> String {
255        self.lines.join("\r\n") + "\r\n\r\n"
256    }
257}
258
259impl fmt::Display for Response {
260    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
261        write!(fmt, "{}", self.to_string())
262    }
263}
264
265impl From<Response> for String {
266    fn from(item: Response) -> String {
267        item.to_string()
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// A response holding only a status line is written to the stream as
    /// that line followed by the terminating blank line.
    #[async_std::test]
    async fn writes_to_stream() {
        let mut res = Response::new();
        res.set_version("1.1");
        res.set_status_code(200);
        res.set_status_message("OK");
        res.build_head().unwrap();

        let mut sink = Vec::new();
        res.write(&mut sink).await.unwrap();

        let written = String::from_utf8(sink).unwrap();
        assert_eq!(written, "HTTP/1.1 200 OK\r\n\r\n");
    }
}