dbs_uhttp/
response.rs

1// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::io::{Error as WriteError, Write};
5
6use crate::ascii::{COLON, CR, LF, SP};
7use crate::common::{Body, Version};
8use crate::headers::{Header, MediaType};
9use crate::Method;
10
/// An HTTP response status code.
///
/// Each variant stands for one numeric code; the codes are described in the
/// [RFC](https://tools.ietf.org/html/rfc7231#section-6). Use `raw` to obtain
/// the digits that are written on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatusCode {
    /// 100, Continue
    Continue,
    /// 200, OK
    OK,
    /// 204, No Content
    NoContent,
    /// 400, Bad Request
    BadRequest,
    /// 401, Unauthorized
    Unauthorized,
    /// 404, Not Found
    NotFound,
    /// 405, Method Not Allowed
    MethodNotAllowed,
    /// 413, Payload Too Large
    PayloadTooLarge,
    /// 500, Internal Server Error
    InternalServerError,
    /// 501, Not Implemented
    NotImplemented,
    /// 503, Service Unavailable
    ServiceUnavailable,
}

impl StatusCode {
    /// Returns the three ASCII digits of this status code.
    pub fn raw(self) -> &'static [u8; 3] {
        // The match is exhaustive on purpose: adding a variant without
        // assigning it a code must fail to compile.
        let digits: &'static [u8; 3] = match self {
            // 1xx
            StatusCode::Continue => b"100",
            // 2xx
            StatusCode::OK => b"200",
            StatusCode::NoContent => b"204",
            // 4xx
            StatusCode::BadRequest => b"400",
            StatusCode::Unauthorized => b"401",
            StatusCode::NotFound => b"404",
            StatusCode::MethodNotAllowed => b"405",
            StatusCode::PayloadTooLarge => b"413",
            // 5xx
            StatusCode::InternalServerError => b"500",
            StatusCode::NotImplemented => b"501",
            StatusCode::ServiceUnavailable => b"503",
        };
        digits
    }
}
59
60#[derive(Debug, Eq, PartialEq)]
61struct StatusLine {
62    http_version: Version,
63    status_code: StatusCode,
64}
65
66impl StatusLine {
67    fn new(http_version: Version, status_code: StatusCode) -> Self {
68        Self {
69            http_version,
70            status_code,
71        }
72    }
73
74    fn write_all<T: Write>(&self, mut buf: T) -> Result<(), WriteError> {
75        buf.write_all(self.http_version.raw())?;
76        buf.write_all(&[SP])?;
77        buf.write_all(self.status_code.raw())?;
78        buf.write_all(&[SP, CR, LF])?;
79
80        Ok(())
81    }
82}
83
84/// Wrapper over the list of headers associated with a HTTP Response.
85/// When creating a ResponseHeaders object, the content type is initialized to `text/plain`.
86/// The content type can be updated with a call to `set_content_type`.
87#[derive(Debug, Eq, PartialEq)]
88pub struct ResponseHeaders {
89    content_length: i32,
90    content_type: MediaType,
91    deprecation: bool,
92    server: String,
93    allow: Vec<Method>,
94    accept_encoding: bool,
95}
96
97impl Default for ResponseHeaders {
98    fn default() -> Self {
99        Self {
100            content_length: Default::default(),
101            content_type: Default::default(),
102            deprecation: false,
103            server: String::from("Firecracker API"),
104            allow: Vec::new(),
105            accept_encoding: false,
106        }
107    }
108}
109
110impl ResponseHeaders {
111    // The logic pertaining to `Allow` header writing.
112    fn write_allow_header<T: Write>(&self, buf: &mut T) -> Result<(), WriteError> {
113        if self.allow.is_empty() {
114            return Ok(());
115        }
116
117        buf.write_all(b"Allow: ")?;
118
119        let delimitator = b", ";
120        for (idx, method) in self.allow.iter().enumerate() {
121            buf.write_all(method.raw())?;
122            // We check above that `self.allow` is not empty.
123            if idx < self.allow.len() - 1 {
124                buf.write_all(delimitator)?;
125            }
126        }
127
128        buf.write_all(&[CR, LF])
129    }
130
131    // The logic pertaining to `Deprecation` header writing.
132    fn write_deprecation_header<T: Write>(&self, buf: &mut T) -> Result<(), WriteError> {
133        if !self.deprecation {
134            return Ok(());
135        }
136
137        buf.write_all(b"Deprecation: true")?;
138        buf.write_all(&[CR, LF])
139    }
140
141    /// Writes the headers to `buf` using the HTTP specification.
142    pub fn write_all<T: Write>(&self, buf: &mut T) -> Result<(), WriteError> {
143        buf.write_all(Header::Server.raw())?;
144        buf.write_all(&[COLON, SP])?;
145        buf.write_all(self.server.as_bytes())?;
146        buf.write_all(&[CR, LF])?;
147
148        buf.write_all(b"Connection: keep-alive")?;
149        buf.write_all(&[CR, LF])?;
150
151        self.write_allow_header(buf)?;
152        self.write_deprecation_header(buf)?;
153
154        if self.content_length != 0 {
155            buf.write_all(Header::ContentType.raw())?;
156            buf.write_all(&[COLON, SP])?;
157            buf.write_all(self.content_type.as_str().as_bytes())?;
158            buf.write_all(&[CR, LF])?;
159
160            buf.write_all(Header::ContentLength.raw())?;
161            buf.write_all(&[COLON, SP])?;
162            buf.write_all(self.content_length.to_string().as_bytes())?;
163            buf.write_all(&[CR, LF])?;
164
165            if self.accept_encoding {
166                buf.write_all(Header::AcceptEncoding.raw())?;
167                buf.write_all(&[COLON, SP])?;
168                buf.write_all(b"identity")?;
169                buf.write_all(&[CR, LF])?;
170            }
171        }
172
173        buf.write_all(&[CR, LF])
174    }
175
176    // Sets the content length to be written in the HTTP response.
177    fn set_content_length(&mut self, content_length: i32) {
178        self.content_length = content_length;
179    }
180
181    /// Sets the HTTP response header server.
182    pub fn set_server(&mut self, server: &str) {
183        self.server = String::from(server);
184    }
185
186    /// Sets the content type to be written in the HTTP response.
187    pub fn set_content_type(&mut self, content_type: MediaType) {
188        self.content_type = content_type;
189    }
190
191    /// Sets the `Deprecation` header to be written in the HTTP response.
192    /// https://tools.ietf.org/id/draft-dalal-deprecation-header-03.html
193    #[allow(unused)]
194    pub fn set_deprecation(&mut self) {
195        self.deprecation = true;
196    }
197
198    /// Sets the encoding type to be written in the HTTP response.
199    #[allow(unused)]
200    pub fn set_encoding(&mut self) {
201        self.accept_encoding = true;
202    }
203}
204
205/// Wrapper over an HTTP Response.
206///
207/// The Response is created using a `Version` and a `StatusCode`. When creating a Response object,
208/// the body is initialized to `None` and the header is initialized with the `default` value. The body
209/// can be updated with a call to `set_body`. The header can be updated with `set_content_type` and
210/// `set_server`.
211#[derive(Debug, Eq, PartialEq)]
212pub struct Response {
213    status_line: StatusLine,
214    headers: ResponseHeaders,
215    body: Option<Body>,
216}
217
218impl Response {
219    /// Creates a new HTTP `Response` with an empty body.
220    pub fn new(http_version: Version, status_code: StatusCode) -> Self {
221        Self {
222            status_line: StatusLine::new(http_version, status_code),
223            headers: ResponseHeaders::default(),
224            body: Default::default(),
225        }
226    }
227
228    /// Updates the body of the `Response`.
229    ///
230    /// This function has side effects because it also updates the headers:
231    /// - `ContentLength`: this is set to the length of the specified body.
232    pub fn set_body(&mut self, body: Body) {
233        self.headers.set_content_length(body.len() as i32);
234        self.body = Some(body);
235    }
236
237    /// Updates the content type of the `Response`.
238    pub fn set_content_type(&mut self, content_type: MediaType) {
239        self.headers.set_content_type(content_type);
240    }
241
242    /// Marks the `Response` as deprecated.
243    pub fn set_deprecation(&mut self) {
244        self.headers.set_deprecation();
245    }
246
247    /// Updates the encoding type of `Response`.
248    pub fn set_encoding(&mut self) {
249        self.headers.set_encoding();
250    }
251
252    /// Sets the HTTP response server.
253    pub fn set_server(&mut self, server: &str) {
254        self.headers.set_server(server);
255    }
256
257    /// Sets the HTTP allowed methods.
258    pub fn set_allow(&mut self, methods: Vec<Method>) {
259        self.headers.allow = methods;
260    }
261
262    /// Allows a specific HTTP method.
263    pub fn allow_method(&mut self, method: Method) {
264        self.headers.allow.push(method);
265    }
266
267    fn write_body<T: Write>(&self, mut buf: T) -> Result<(), WriteError> {
268        if let Some(ref body) = self.body {
269            buf.write_all(body.raw())?;
270        }
271        Ok(())
272    }
273
274    /// Writes the content of the `Response` to the specified `buf`.
275    ///
276    /// # Errors
277    /// Returns an error when the buffer is not large enough.
278    pub fn write_all<T: Write>(&self, mut buf: &mut T) -> Result<(), WriteError> {
279        self.status_line.write_all(&mut buf)?;
280        self.headers.write_all(&mut buf)?;
281        self.write_body(&mut buf)?;
282
283        Ok(())
284    }
285
286    /// Returns the Status Code of the Response.
287    pub fn status(&self) -> StatusCode {
288        self.status_line.status_code
289    }
290
291    /// Returns the Body of the response. If the response does not have a body,
292    /// it returns None.
293    pub fn body(&self) -> Option<Body> {
294        self.body.clone()
295    }
296
297    /// Returns the Content Length of the response.
298    pub fn content_length(&self) -> i32 {
299        self.headers.content_length
300    }
301
302    /// Returns the Content Type of the response.
303    pub fn content_type(&self) -> MediaType {
304        self.headers.content_type
305    }
306
307    /// Returns the deprecation status of the response.
308    pub fn deprecation(&self) -> bool {
309        self.headers.deprecation
310    }
311
312    /// Returns the HTTP Version of the response.
313    pub fn http_version(&self) -> Version {
314        self.status_line.http_version
315    }
316
317    /// Returns the allowed HTTP methods.
318    pub fn allow(&self) -> Vec<Method> {
319        self.headers.allow.clone()
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    // Payload shared by the tests that attach a body (14 bytes long).
    const BODY: &str = "This is a test";

    // Serializes `response` into a zeroed buffer that is exactly as long as `expected`
    // and checks the bytes. The write fails if the response is longer than `expected`;
    // the comparison fails on the zero padding if it is shorter.
    fn assert_wire_format(response: &Response, expected: &[u8]) {
        let mut wire = vec![0u8; expected.len()];
        assert!(response.write_all(&mut wire.as_mut_slice()).is_ok());
        assert_eq!(wire.as_slice(), expected);
    }

    // Checks the accessors of a `200` HTTP/1.0 response carrying `BODY` as plain text.
    fn assert_plain_text_ok(response: &Response) {
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.body().unwrap(), Body::new(BODY));
        assert_eq!(response.http_version(), Version::Http10);
        assert_eq!(response.content_length(), 14);
        assert_eq!(response.content_type(), MediaType::PlainText);
    }

    #[test]
    fn test_write_response() {
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        response.set_body(Body::new(BODY));
        response.set_content_type(MediaType::PlainText);
        response.set_encoding();

        assert_plain_text_ok(&response);
        assert_wire_format(
            &response,
            b"HTTP/1.0 200 \r\n\
              Server: Firecracker API\r\n\
              Connection: keep-alive\r\n\
              Content-Type: text/plain\r\n\
              Content-Length: 14\r\n\
              Accept-Encoding: identity\r\n\r\n\
              This is a test",
        );

        // Test response `Allow` header.
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        let allowed_methods = vec![Method::Get, Method::Patch, Method::Put];
        response.set_allow(allowed_methods.clone());
        assert_eq!(response.allow(), allowed_methods);
        assert_wire_format(
            &response,
            b"HTTP/1.0 200 \r\n\
              Server: Firecracker API\r\n\
              Connection: keep-alive\r\n\
              Allow: GET, PATCH, PUT\r\n\r\n",
        );

        // Test write failed: a one byte buffer cannot hold the response.
        let mut tiny = [0u8; 1];
        let mut sink: &mut [u8] = &mut tiny;
        assert!(response.write_all(&mut sink).is_err());
    }

    #[test]
    fn test_set_server() {
        let server = "rust-vmm API";

        let mut response = Response::new(Version::Http10, StatusCode::OK);
        response.set_body(Body::new(BODY));
        response.set_content_type(MediaType::PlainText);
        response.set_server(server);

        assert_plain_text_ok(&response);

        let expected_response = format!(
            "HTTP/1.0 200 \r\n\
             Server: {}\r\n\
             Connection: keep-alive\r\n\
             Content-Type: text/plain\r\n\
             Content-Length: 14\r\n\r\n\
             This is a test",
            server
        );
        assert_wire_format(&response, expected_response.as_bytes());
    }

    #[test]
    fn test_status_code() {
        let table: [(StatusCode, &[u8; 3]); 11] = [
            (StatusCode::Continue, b"100"),
            (StatusCode::OK, b"200"),
            (StatusCode::NoContent, b"204"),
            (StatusCode::BadRequest, b"400"),
            (StatusCode::Unauthorized, b"401"),
            (StatusCode::NotFound, b"404"),
            (StatusCode::MethodNotAllowed, b"405"),
            (StatusCode::PayloadTooLarge, b"413"),
            (StatusCode::InternalServerError, b"500"),
            (StatusCode::NotImplemented, b"501"),
            (StatusCode::ServiceUnavailable, b"503"),
        ];

        for &(code, digits) in table.iter() {
            assert_eq!(code.raw(), digits);
        }
    }

    #[test]
    fn test_allow_method() {
        let mut response = Response::new(Version::Http10, StatusCode::MethodNotAllowed);
        for method in vec![Method::Get, Method::Put] {
            response.allow_method(method);
        }
        assert_eq!(response.allow(), vec![Method::Get, Method::Put]);
    }

    #[test]
    fn test_deprecation() {
        // Test a deprecated response with body.
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        response.set_body(Body::new(BODY));
        response.set_content_type(MediaType::PlainText);
        response.set_encoding();
        response.set_deprecation();

        assert_plain_text_ok(&response);
        assert!(response.deprecation());
        assert_wire_format(
            &response,
            b"HTTP/1.0 200 \r\n\
              Server: Firecracker API\r\n\
              Connection: keep-alive\r\n\
              Deprecation: true\r\n\
              Content-Type: text/plain\r\n\
              Content-Length: 14\r\n\
              Accept-Encoding: identity\r\n\r\n\
              This is a test",
        );

        // Test a deprecated response without a body.
        let mut response = Response::new(Version::Http10, StatusCode::NoContent);
        response.set_deprecation();

        assert_eq!(response.status(), StatusCode::NoContent);
        assert_eq!(response.http_version(), Version::Http10);
        assert!(response.deprecation());
        assert_wire_format(
            &response,
            b"HTTP/1.0 204 \r\n\
              Server: Firecracker API\r\n\
              Connection: keep-alive\r\n\
              Deprecation: true\r\n\r\n",
        );
    }

    #[test]
    fn test_equal() {
        // Same version and status code: the responses compare equal.
        let make = |code| Response::new(Version::Http10, code);
        assert_eq!(
            make(StatusCode::MethodNotAllowed),
            make(StatusCode::MethodNotAllowed)
        );

        // Different status codes: the responses differ.
        assert_ne!(make(StatusCode::OK), make(StatusCode::BadRequest));
    }
}