use std::io::{Error as WriteError, Write};
use crate::ascii::{COLON, CR, LF, SP};
use crate::common::{Body, Version};
use crate::headers::{Header, MediaType};
use crate::Method;
/// HTTP status codes that a response can carry.
///
/// Each variant maps to a fixed three-digit ASCII code, available through
/// [`StatusCode::raw`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatusCode {
    /// Status `100`.
    Continue,
    /// Status `200`.
    OK,
    /// Status `204`.
    NoContent,
    /// Status `400`.
    BadRequest,
    /// Status `401`.
    Unauthorized,
    /// Status `404`.
    NotFound,
    /// Status `405`.
    MethodNotAllowed,
    /// Status `413`.
    PayloadTooLarge,
    /// Status `500`.
    InternalServerError,
    /// Status `501`.
    NotImplemented,
    /// Status `503`.
    ServiceUnavailable,
}

impl StatusCode {
    /// Returns the three ASCII digits that represent this status code on the wire.
    pub fn raw(self) -> &'static [u8; 3] {
        // Exhaustive on purpose: adding a variant must force a decision here.
        let digits: &'static [u8; 3] = match self {
            StatusCode::Continue => b"100",
            StatusCode::OK => b"200",
            StatusCode::NoContent => b"204",
            StatusCode::BadRequest => b"400",
            StatusCode::Unauthorized => b"401",
            StatusCode::NotFound => b"404",
            StatusCode::MethodNotAllowed => b"405",
            StatusCode::PayloadTooLarge => b"413",
            StatusCode::InternalServerError => b"500",
            StatusCode::NotImplemented => b"501",
            StatusCode::ServiceUnavailable => b"503",
        };
        digits
    }
}
/// First line of a response: the protocol version followed by the status code.
#[derive(Debug, PartialEq, Eq)]
struct StatusLine {
    /// Protocol version written at the start of the line.
    http_version: Version,
    /// Status code written after the version.
    status_code: StatusCode,
}
impl StatusLine {
fn new(http_version: Version, status_code: StatusCode) -> Self {
Self {
http_version,
status_code,
}
}
fn write_all<T: Write>(&self, mut buf: T) -> Result<(), WriteError> {
buf.write_all(self.http_version.raw())?;
buf.write_all(&[SP])?;
buf.write_all(self.status_code.raw())?;
buf.write_all(&[SP, CR, LF])?;
Ok(())
}
}
/// Header values that accompany a response.
#[derive(Debug, PartialEq, Eq)]
pub struct ResponseHeaders {
    /// Value of the content length header; the body-related headers are only
    /// serialized when this is non-zero.
    content_length: i32,
    /// Media type reported in the content type header.
    content_type: MediaType,
    /// When `true`, a `Deprecation: true` header is serialized.
    deprecation: bool,
    /// Value of the server header.
    server: String,
    /// Methods listed in the `Allow` header; the header is omitted when empty.
    allow: Vec<Method>,
    /// When `true` (and the content length is non-zero), an accept-encoding
    /// header with the value `identity` is serialized.
    accept_encoding: bool,
}
impl Default for ResponseHeaders {
fn default() -> Self {
Self {
content_length: Default::default(),
content_type: Default::default(),
deprecation: false,
server: String::from("Firecracker API"),
allow: Vec::new(),
accept_encoding: false,
}
}
}
impl ResponseHeaders {
    /// Writes a single `name: value` header line terminated by CR LF.
    fn write_field<T: Write>(buf: &mut T, name: &[u8], value: &[u8]) -> Result<(), WriteError> {
        buf.write_all(name)?;
        buf.write_all(&[COLON, SP])?;
        buf.write_all(value)?;
        buf.write_all(&[CR, LF])
    }

    /// Writes the `Allow` header as a comma-separated list of methods.
    ///
    /// Nothing is written when no methods have been registered.
    fn write_allow_header<T: Write>(&self, buf: &mut T) -> Result<(), WriteError> {
        let mut methods = self.allow.iter();
        let first = match methods.next() {
            Some(method) => method,
            None => return Ok(()),
        };

        buf.write_all(b"Allow: ")?;
        buf.write_all(first.raw())?;
        // Every method after the first is preceded by the separator, so the
        // list never ends with a dangling ", ".
        for method in methods {
            buf.write_all(b", ")?;
            buf.write_all(method.raw())?;
        }
        buf.write_all(&[CR, LF])
    }

    /// Writes `Deprecation: true` when the deprecation flag is set; otherwise
    /// writes nothing.
    fn write_deprecation_header<T: Write>(&self, buf: &mut T) -> Result<(), WriteError> {
        if self.deprecation {
            buf.write_all(b"Deprecation: true")?;
            buf.write_all(&[CR, LF])?;
        }
        Ok(())
    }

    /// Serializes all headers into `buf`, followed by the empty line that ends
    /// the header section.
    ///
    /// The server and `Connection: keep-alive` headers are always written,
    /// then the optional `Allow` and `Deprecation` headers. The content type,
    /// content length and (if enabled) accept-encoding headers are written
    /// only when the content length is non-zero.
    ///
    /// NOTE(review): a response with an empty body therefore carries no
    /// content length while still advertising keep-alive; confirm that
    /// clients handle this as intended.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the writer.
    pub fn write_all<T: Write>(&self, buf: &mut T) -> Result<(), WriteError> {
        Self::write_field(buf, Header::Server.raw(), self.server.as_bytes())?;
        buf.write_all(b"Connection: keep-alive")?;
        buf.write_all(&[CR, LF])?;
        self.write_allow_header(buf)?;
        self.write_deprecation_header(buf)?;

        let has_body = self.content_length != 0;
        if has_body {
            Self::write_field(
                buf,
                Header::ContentType.raw(),
                self.content_type.as_str().as_bytes(),
            )?;

            let length = self.content_length.to_string();
            Self::write_field(buf, Header::ContentLength.raw(), length.as_bytes())?;

            if self.accept_encoding {
                Self::write_field(buf, Header::AcceptEncoding.raw(), b"identity")?;
            }
        }

        // Empty line separating the headers from the body.
        buf.write_all(&[CR, LF])
    }

    /// Stores the content length to report.
    fn set_content_length(&mut self, content_length: i32) {
        self.content_length = content_length;
    }

    /// Replaces the value of the server header.
    pub fn set_server(&mut self, server: &str) {
        self.server = server.to_owned();
    }

    /// Replaces the media type reported in the content type header.
    pub fn set_content_type(&mut self, content_type: MediaType) {
        self.content_type = content_type;
    }

    /// Turns on the `Deprecation: true` header. There is no way to turn it
    /// back off.
    #[allow(unused)]
    pub fn set_deprecation(&mut self) {
        self.deprecation = true;
    }

    /// Turns on the accept-encoding header. There is no way to turn it back
    /// off.
    #[allow(unused)]
    pub fn set_encoding(&mut self) {
        self.accept_encoding = true;
    }
}
/// A complete response: status line, headers and an optional body.
#[derive(Debug, PartialEq, Eq)]
pub struct Response {
    /// Protocol version and status code.
    status_line: StatusLine,
    /// Headers serialized after the status line.
    headers: ResponseHeaders,
    /// Payload serialized after the headers, if one has been set.
    body: Option<Body>,
}
impl Response {
pub fn new(http_version: Version, status_code: StatusCode) -> Self {
Self {
status_line: StatusLine::new(http_version, status_code),
headers: ResponseHeaders::default(),
body: Default::default(),
}
}
pub fn set_body(&mut self, body: Body) {
self.headers.set_content_length(body.len() as i32);
self.body = Some(body);
}
pub fn set_content_type(&mut self, content_type: MediaType) {
self.headers.set_content_type(content_type);
}
pub fn set_deprecation(&mut self) {
self.headers.set_deprecation();
}
pub fn set_encoding(&mut self) {
self.headers.set_encoding();
}
pub fn set_server(&mut self, server: &str) {
self.headers.set_server(server);
}
pub fn set_allow(&mut self, methods: Vec<Method>) {
self.headers.allow = methods;
}
pub fn allow_method(&mut self, method: Method) {
self.headers.allow.push(method);
}
fn write_body<T: Write>(&self, mut buf: T) -> Result<(), WriteError> {
if let Some(ref body) = self.body {
buf.write_all(body.raw())?;
}
Ok(())
}
pub fn write_all<T: Write>(&self, mut buf: &mut T) -> Result<(), WriteError> {
self.status_line.write_all(&mut buf)?;
self.headers.write_all(&mut buf)?;
self.write_body(&mut buf)?;
Ok(())
}
pub fn status(&self) -> StatusCode {
self.status_line.status_code
}
pub fn body(&self) -> Option<Body> {
self.body.clone()
}
pub fn content_length(&self) -> i32 {
self.headers.content_length
}
pub fn content_type(&self) -> MediaType {
self.headers.content_type
}
pub fn deprecation(&self) -> bool {
self.headers.deprecation
}
pub fn http_version(&self) -> Version {
self.status_line.http_version
}
pub fn allow(&self) -> Vec<Method> {
self.headers.allow.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `response` into a buffer exactly `expected.len()` bytes long
    /// and checks that the write succeeds and produces `expected`.
    ///
    /// Because the buffer has no slack, a response that is longer than
    /// expected fails the write, and a shorter one fails the comparison.
    fn assert_serializes_to(response: &Response, expected: &[u8]) {
        let mut storage = vec![0u8; expected.len()];
        let mut sink: &mut [u8] = storage.as_mut_slice();
        assert!(response.write_all(&mut sink).is_ok());
        assert_eq!(storage.as_slice(), expected);
    }

    #[test]
    fn test_write_response() {
        let payload = "This is a test";
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        response.set_body(Body::new(payload));
        response.set_content_type(MediaType::PlainText);
        response.set_encoding();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.body().unwrap(), Body::new(payload));
        assert_eq!(response.http_version(), Version::Http10);
        assert_eq!(response.content_length(), 14);
        assert_eq!(response.content_type(), MediaType::PlainText);

        assert_serializes_to(
            &response,
            b"HTTP/1.0 200 \r\n\
              Server: Firecracker API\r\n\
              Connection: keep-alive\r\n\
              Content-Type: text/plain\r\n\
              Content-Length: 14\r\n\
              Accept-Encoding: identity\r\n\r\n\
              This is a test",
        );

        // A body-less response that only advertises the allowed methods.
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        let permitted = vec![Method::Get, Method::Patch, Method::Put];
        response.set_allow(permitted.clone());
        assert_eq!(response.allow(), permitted);

        assert_serializes_to(
            &response,
            b"HTTP/1.0 200 \r\n\
              Server: Firecracker API\r\n\
              Connection: keep-alive\r\n\
              Allow: GET, PATCH, PUT\r\n\r\n",
        );

        // A buffer that is too small must surface a write error.
        let mut tiny = [0u8; 1];
        let mut sink: &mut [u8] = &mut tiny[..];
        assert!(response.write_all(&mut sink).is_err());
    }

    #[test]
    fn test_set_server() {
        let payload = "This is a test";
        let server_name = "rust-vmm API";
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        response.set_body(Body::new(payload));
        response.set_content_type(MediaType::PlainText);
        response.set_server(server_name);

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.body().unwrap(), Body::new(payload));
        assert_eq!(response.http_version(), Version::Http10);
        assert_eq!(response.content_length(), 14);
        assert_eq!(response.content_type(), MediaType::PlainText);

        let expected = format!(
            "HTTP/1.0 200 \r\n\
             Server: {}\r\n\
             Connection: keep-alive\r\n\
             Content-Type: text/plain\r\n\
             Content-Length: 14\r\n\r\n\
             This is a test",
            server_name
        );
        assert_serializes_to(&response, expected.as_bytes());
    }

    #[test]
    fn test_status_code() {
        assert_eq!(StatusCode::Continue.raw(), b"100");
        assert_eq!(StatusCode::OK.raw(), b"200");
        assert_eq!(StatusCode::NoContent.raw(), b"204");
        assert_eq!(StatusCode::BadRequest.raw(), b"400");
        assert_eq!(StatusCode::Unauthorized.raw(), b"401");
        assert_eq!(StatusCode::NotFound.raw(), b"404");
        assert_eq!(StatusCode::MethodNotAllowed.raw(), b"405");
        assert_eq!(StatusCode::PayloadTooLarge.raw(), b"413");
        assert_eq!(StatusCode::InternalServerError.raw(), b"500");
        assert_eq!(StatusCode::NotImplemented.raw(), b"501");
        assert_eq!(StatusCode::ServiceUnavailable.raw(), b"503");
    }

    #[test]
    fn test_allow_method() {
        let mut response = Response::new(Version::Http10, StatusCode::MethodNotAllowed);
        response.allow_method(Method::Get);
        response.allow_method(Method::Put);
        assert_eq!(response.allow(), vec![Method::Get, Method::Put]);
    }

    #[test]
    fn test_deprecation() {
        let payload = "This is a test";
        let mut response = Response::new(Version::Http10, StatusCode::OK);
        response.set_body(Body::new(payload));
        response.set_content_type(MediaType::PlainText);
        response.set_encoding();
        response.set_deprecation();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.body().unwrap(), Body::new(payload));
        assert_eq!(response.http_version(), Version::Http10);
        assert_eq!(response.content_length(), 14);
        assert_eq!(response.content_type(), MediaType::PlainText);
        assert!(response.deprecation());

        assert_serializes_to(
            &response,
            b"HTTP/1.0 200 \r\n\
              Server: Firecracker API\r\n\
              Connection: keep-alive\r\n\
              Deprecation: true\r\n\
              Content-Type: text/plain\r\n\
              Content-Length: 14\r\n\
              Accept-Encoding: identity\r\n\r\n\
              This is a test",
        );

        // The deprecation header is also emitted when there is no body.
        let mut response = Response::new(Version::Http10, StatusCode::NoContent);
        response.set_deprecation();

        assert_eq!(response.status(), StatusCode::NoContent);
        assert_eq!(response.http_version(), Version::Http10);
        assert!(response.deprecation());

        assert_serializes_to(
            &response,
            b"HTTP/1.0 204 \r\n\
              Server: Firecracker API\r\n\
              Connection: keep-alive\r\n\
              Deprecation: true\r\n\r\n",
        );
    }

    #[test]
    fn test_equal() {
        let first = Response::new(Version::Http10, StatusCode::MethodNotAllowed);
        let second = Response::new(Version::Http10, StatusCode::MethodNotAllowed);
        assert_eq!(first, second);

        let ok = Response::new(Version::Http10, StatusCode::OK);
        let bad_request = Response::new(Version::Http10, StatusCode::BadRequest);
        assert_ne!(ok, bad_request);
    }
}