1use std::io::{Error as WriteError, Write};
5
6use crate::ascii::{COLON, CR, LF, SP};
7use crate::common::{Body, Version};
8use crate::headers::{Header, MediaType};
9use crate::Method;
10
/// HTTP status codes that a response built by this module can carry.
///
/// Only the codes this server emits are modeled; `StatusCode::raw` yields the
/// three-digit ASCII form that is written on the status line.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StatusCode {
    /// 100 Continue.
    Continue,
    /// 200 OK.
    OK,
    /// 204 No Content.
    NoContent,
    /// 400 Bad Request.
    BadRequest,
    /// 401 Unauthorized.
    Unauthorized,
    /// 404 Not Found.
    NotFound,
    /// 405 Method Not Allowed.
    MethodNotAllowed,
    /// 413 Payload Too Large.
    PayloadTooLarge,
    /// 500 Internal Server Error.
    InternalServerError,
    /// 501 Not Implemented.
    NotImplemented,
    /// 503 Service Unavailable.
    ServiceUnavailable,
}

impl StatusCode {
    /// Returns the three ASCII digits of this status code, e.g. `b"404"` for
    /// `StatusCode::NotFound`.
    pub fn raw(self) -> &'static [u8; 3] {
        let digits: &'static [u8; 3] = match self {
            StatusCode::Continue => b"100",
            StatusCode::OK => b"200",
            StatusCode::NoContent => b"204",
            StatusCode::BadRequest => b"400",
            StatusCode::Unauthorized => b"401",
            StatusCode::NotFound => b"404",
            StatusCode::MethodNotAllowed => b"405",
            StatusCode::PayloadTooLarge => b"413",
            StatusCode::InternalServerError => b"500",
            StatusCode::NotImplemented => b"501",
            StatusCode::ServiceUnavailable => b"503",
        };
        digits
    }
}
59
60#[derive(Debug, Eq, PartialEq)]
61struct StatusLine {
62 http_version: Version,
63 status_code: StatusCode,
64}
65
66impl StatusLine {
67 fn new(http_version: Version, status_code: StatusCode) -> Self {
68 Self {
69 http_version,
70 status_code,
71 }
72 }
73
74 fn write_all<T: Write>(&self, mut buf: T) -> Result<(), WriteError> {
75 buf.write_all(self.http_version.raw())?;
76 buf.write_all(&[SP])?;
77 buf.write_all(self.status_code.raw())?;
78 buf.write_all(&[SP, CR, LF])?;
79
80 Ok(())
81 }
82}
83
84#[derive(Debug, Eq, PartialEq)]
88pub struct ResponseHeaders {
89 content_length: i32,
90 content_type: MediaType,
91 deprecation: bool,
92 server: String,
93 allow: Vec<Method>,
94 accept_encoding: bool,
95}
96
97impl Default for ResponseHeaders {
98 fn default() -> Self {
99 Self {
100 content_length: Default::default(),
101 content_type: Default::default(),
102 deprecation: false,
103 server: String::from("Firecracker API"),
104 allow: Vec::new(),
105 accept_encoding: false,
106 }
107 }
108}
109
110impl ResponseHeaders {
111 fn write_allow_header<T: Write>(&self, buf: &mut T) -> Result<(), WriteError> {
113 if self.allow.is_empty() {
114 return Ok(());
115 }
116
117 buf.write_all(b"Allow: ")?;
118
119 let delimitator = b", ";
120 for (idx, method) in self.allow.iter().enumerate() {
121 buf.write_all(method.raw())?;
122 if idx < self.allow.len() - 1 {
124 buf.write_all(delimitator)?;
125 }
126 }
127
128 buf.write_all(&[CR, LF])
129 }
130
131 fn write_deprecation_header<T: Write>(&self, buf: &mut T) -> Result<(), WriteError> {
133 if !self.deprecation {
134 return Ok(());
135 }
136
137 buf.write_all(b"Deprecation: true")?;
138 buf.write_all(&[CR, LF])
139 }
140
141 pub fn write_all<T: Write>(&self, buf: &mut T) -> Result<(), WriteError> {
143 buf.write_all(Header::Server.raw())?;
144 buf.write_all(&[COLON, SP])?;
145 buf.write_all(self.server.as_bytes())?;
146 buf.write_all(&[CR, LF])?;
147
148 buf.write_all(b"Connection: keep-alive")?;
149 buf.write_all(&[CR, LF])?;
150
151 self.write_allow_header(buf)?;
152 self.write_deprecation_header(buf)?;
153
154 if self.content_length != 0 {
155 buf.write_all(Header::ContentType.raw())?;
156 buf.write_all(&[COLON, SP])?;
157 buf.write_all(self.content_type.as_str().as_bytes())?;
158 buf.write_all(&[CR, LF])?;
159
160 buf.write_all(Header::ContentLength.raw())?;
161 buf.write_all(&[COLON, SP])?;
162 buf.write_all(self.content_length.to_string().as_bytes())?;
163 buf.write_all(&[CR, LF])?;
164
165 if self.accept_encoding {
166 buf.write_all(Header::AcceptEncoding.raw())?;
167 buf.write_all(&[COLON, SP])?;
168 buf.write_all(b"identity")?;
169 buf.write_all(&[CR, LF])?;
170 }
171 }
172
173 buf.write_all(&[CR, LF])
174 }
175
176 fn set_content_length(&mut self, content_length: i32) {
178 self.content_length = content_length;
179 }
180
181 pub fn set_server(&mut self, server: &str) {
183 self.server = String::from(server);
184 }
185
186 pub fn set_content_type(&mut self, content_type: MediaType) {
188 self.content_type = content_type;
189 }
190
191 #[allow(unused)]
194 pub fn set_deprecation(&mut self) {
195 self.deprecation = true;
196 }
197
198 #[allow(unused)]
200 pub fn set_encoding(&mut self) {
201 self.accept_encoding = true;
202 }
203}
204
205#[derive(Debug, Eq, PartialEq)]
212pub struct Response {
213 status_line: StatusLine,
214 headers: ResponseHeaders,
215 body: Option<Body>,
216}
217
218impl Response {
219 pub fn new(http_version: Version, status_code: StatusCode) -> Self {
221 Self {
222 status_line: StatusLine::new(http_version, status_code),
223 headers: ResponseHeaders::default(),
224 body: Default::default(),
225 }
226 }
227
228 pub fn set_body(&mut self, body: Body) {
233 self.headers.set_content_length(body.len() as i32);
234 self.body = Some(body);
235 }
236
237 pub fn set_content_type(&mut self, content_type: MediaType) {
239 self.headers.set_content_type(content_type);
240 }
241
242 pub fn set_deprecation(&mut self) {
244 self.headers.set_deprecation();
245 }
246
247 pub fn set_encoding(&mut self) {
249 self.headers.set_encoding();
250 }
251
252 pub fn set_server(&mut self, server: &str) {
254 self.headers.set_server(server);
255 }
256
257 pub fn set_allow(&mut self, methods: Vec<Method>) {
259 self.headers.allow = methods;
260 }
261
262 pub fn allow_method(&mut self, method: Method) {
264 self.headers.allow.push(method);
265 }
266
267 fn write_body<T: Write>(&self, mut buf: T) -> Result<(), WriteError> {
268 if let Some(ref body) = self.body {
269 buf.write_all(body.raw())?;
270 }
271 Ok(())
272 }
273
274 pub fn write_all<T: Write>(&self, mut buf: &mut T) -> Result<(), WriteError> {
279 self.status_line.write_all(&mut buf)?;
280 self.headers.write_all(&mut buf)?;
281 self.write_body(&mut buf)?;
282
283 Ok(())
284 }
285
286 pub fn status(&self) -> StatusCode {
288 self.status_line.status_code
289 }
290
291 pub fn body(&self) -> Option<Body> {
294 self.body.clone()
295 }
296
297 pub fn content_length(&self) -> i32 {
299 self.headers.content_length
300 }
301
302 pub fn content_type(&self) -> MediaType {
304 self.headers.content_type
305 }
306
307 pub fn deprecation(&self) -> bool {
309 self.headers.deprecation
310 }
311
312 pub fn http_version(&self) -> Version {
314 self.status_line.http_version
315 }
316
317 pub fn allow(&self) -> Vec<Method> {
319 self.headers.allow.clone()
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    // Checks the exact serialized bytes of a response with a body, then of a
    // body-less response with an Allow list, then that a buffer too small to
    // hold the response makes `write_all` fail.
    #[test]
    fn test_write_response() {
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        let body = "This is a test";
        response.set_body(Body::new(body));
        response.set_content_type(MediaType::PlainText);
        response.set_encoding();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.body().unwrap(), Body::new(body));
        assert_eq!(response.http_version(), Version::Http10);
        assert_eq!(response.content_length(), 14);
        assert_eq!(response.content_type(), MediaType::PlainText);

        let expected_response: &'static [u8] = b"HTTP/1.0 200 \r\n\
            Server: Firecracker API\r\n\
            Connection: keep-alive\r\n\
            Content-Type: text/plain\r\n\
            Content-Length: 14\r\n\
            Accept-Encoding: identity\r\n\r\n\
            This is a test";

        // The buffer is exactly `expected_response.len()` (153) bytes: extra
        // output makes `write_all` fail, missing output leaves trailing zeros
        // that break the comparison.
        let mut response_buf: [u8; 153] = [0; 153];
        assert!(response.write_all(&mut response_buf.as_mut()).is_ok());
        assert_eq!(response_buf.as_ref(), expected_response);

        // A body-less response: only the Allow header is added, and no
        // Content-Type / Content-Length is emitted.
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        let allowed_methods = vec![Method::Get, Method::Patch, Method::Put];
        response.set_allow(allowed_methods.clone());
        assert_eq!(response.allow(), allowed_methods);

        let expected_response: &'static [u8] = b"HTTP/1.0 200 \r\n\
            Server: Firecracker API\r\n\
            Connection: keep-alive\r\n\
            Allow: GET, PATCH, PUT\r\n\r\n";
        let mut response_buf: [u8; 90] = [0; 90];
        assert!(response.write_all(&mut response_buf.as_mut()).is_ok());
        assert_eq!(response_buf.as_ref(), expected_response);

        // A 1-byte buffer cannot hold the response, so the write must error.
        let mut response_buf: [u8; 1] = [0; 1];
        assert!(response.write_all(&mut response_buf.as_mut()).is_err());
    }

    // A custom server name replaces the default in the Server header.
    #[test]
    fn test_set_server() {
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        let body = "This is a test";
        let server = "rust-vmm API";
        response.set_body(Body::new(body));
        response.set_content_type(MediaType::PlainText);
        response.set_server(server);

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.body().unwrap(), Body::new(body));
        assert_eq!(response.http_version(), Version::Http10);
        assert_eq!(response.content_length(), 14);
        assert_eq!(response.content_type(), MediaType::PlainText);

        let expected_response = format!(
            "HTTP/1.0 200 \r\n\
             Server: {}\r\n\
             Connection: keep-alive\r\n\
             Content-Type: text/plain\r\n\
             Content-Length: 14\r\n\r\n\
             This is a test",
            server
        );

        // Sized to the exact expected length (123 bytes).
        let mut response_buf: [u8; 123] = [0; 123];
        assert!(response.write_all(&mut response_buf.as_mut()).is_ok());
        assert_eq!(response_buf.as_ref(), expected_response.as_bytes());
    }

    // Every variant maps to its three-digit code.
    #[test]
    fn test_status_code() {
        assert_eq!(StatusCode::Continue.raw(), b"100");
        assert_eq!(StatusCode::OK.raw(), b"200");
        assert_eq!(StatusCode::NoContent.raw(), b"204");
        assert_eq!(StatusCode::BadRequest.raw(), b"400");
        assert_eq!(StatusCode::Unauthorized.raw(), b"401");
        assert_eq!(StatusCode::NotFound.raw(), b"404");
        assert_eq!(StatusCode::MethodNotAllowed.raw(), b"405");
        assert_eq!(StatusCode::PayloadTooLarge.raw(), b"413");
        assert_eq!(StatusCode::InternalServerError.raw(), b"500");
        assert_eq!(StatusCode::NotImplemented.raw(), b"501");
        assert_eq!(StatusCode::ServiceUnavailable.raw(), b"503");
    }

    // `allow_method` appends, preserving insertion order.
    #[test]
    fn test_allow_method() {
        let mut response = Response::new(Version::Http10, StatusCode::MethodNotAllowed);
        response.allow_method(Method::Get);
        response.allow_method(Method::Put);
        assert_eq!(response.allow(), vec![Method::Get, Method::Put]);
    }

    // The Deprecation header is written after Connection and before the
    // content headers, with and without a body.
    #[test]
    fn test_deprecation() {
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        let body = "This is a test";
        response.set_body(Body::new(body));
        response.set_content_type(MediaType::PlainText);
        response.set_encoding();
        response.set_deprecation();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.body().unwrap(), Body::new(body));
        assert_eq!(response.http_version(), Version::Http10);
        assert_eq!(response.content_length(), 14);
        assert_eq!(response.content_type(), MediaType::PlainText);
        assert!(response.deprecation());

        let expected_response: &'static [u8] = b"HTTP/1.0 200 \r\n\
            Server: Firecracker API\r\n\
            Connection: keep-alive\r\n\
            Deprecation: true\r\n\
            Content-Type: text/plain\r\n\
            Content-Length: 14\r\n\
            Accept-Encoding: identity\r\n\r\n\
            This is a test";

        // Sized to the exact expected length (172 bytes).
        let mut response_buf: [u8; 172] = [0; 172];
        assert!(response.write_all(&mut response_buf.as_mut()).is_ok());
        assert_eq!(response_buf.as_ref(), expected_response);

        // Body-less 204 response: only the Deprecation header is added.
        let mut response = Response::new(Version::Http10, StatusCode::NoContent);
        response.set_deprecation();

        assert_eq!(response.status(), StatusCode::NoContent);
        assert_eq!(response.http_version(), Version::Http10);
        assert!(response.deprecation());

        let expected_response: &'static [u8] = b"HTTP/1.0 204 \r\n\
            Server: Firecracker API\r\n\
            Connection: keep-alive\r\n\
            Deprecation: true\r\n\r\n";

        // Sized to the exact expected length (85 bytes).
        let mut response_buf: [u8; 85] = [0; 85];
        assert!(response.write_all(&mut response_buf.as_mut()).is_ok());
        assert_eq!(response_buf.as_ref(), expected_response);
    }

    // Responses compare equal exactly when their contents match.
    #[test]
    fn test_equal() {
        let response = Response::new(Version::Http10, StatusCode::MethodNotAllowed);
        let another_response = Response::new(Version::Http10, StatusCode::MethodNotAllowed);
        assert_eq!(response, another_response);

        let response = Response::new(Version::Http10, StatusCode::OK);
        let another_response = Response::new(Version::Http10, StatusCode::BadRequest);
        assert_ne!(response, another_response);
    }
}