use crate::error::ServerError;
use serde::{Deserialize, Serialize};
use std::io::{BufWriter, Write};
/// An HTTP/1.1 response: a status line, an ordered list of headers and
/// the raw body bytes.
///
/// Write it onto any `std::io::Write` sink with `Response::send`.
#[doc(alias = "http response")]
#[derive(Clone, Debug, Default)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct Response {
    /// Numeric status code placed on the status line (e.g. `200`).
    pub status_code: u16,
    /// Reason phrase placed after the status code (e.g. `"OK"`).
    pub status_text: String,
    /// `(name, value)` header pairs, kept in insertion order; the same
    /// name may appear more than once.
    pub headers: Vec<(String, String)>,
    /// Body bytes written after the blank line that ends the headers.
    pub body: Vec<u8>,
}
impl Response {
    /// Creates a response with the given status line and body and no
    /// headers.
    ///
    /// `Content-Length` and `Connection` are not stored here;
    /// [`Response::send`] adds defaults for them at write time when they
    /// are absent.
    #[doc(alias = "constructor")]
    pub fn new(
        status_code: u16,
        status_text: &str,
        body: Vec<u8>,
    ) -> Self {
        Response {
            status_code,
            status_text: status_text.to_string(),
            headers: Vec::new(),
            body,
        }
    }

    /// Appends a `name: value` header.
    ///
    /// Existing headers with the same name are kept, so this can create
    /// duplicates. Nothing is validated here; [`Response::send`] rejects
    /// names and values that contain CR, LF or NUL.
    #[doc(alias = "set header")]
    pub fn add_header(&mut self, name: &str, value: &str) {
        self.headers.push((name.to_string(), value.to_string()));
    }

    /// Removes every `Connection` header (names compared ASCII
    /// case-insensitively) and appends a single `Connection: value`.
    pub fn set_connection_header(&mut self, value: &str) {
        self.headers.retain(|(name, _)| {
            !name.eq_ignore_ascii_case("connection")
        });
        self.headers
            .push(("Connection".to_string(), value.to_string()));
    }

    /// Serializes the response as an HTTP/1.1 message and writes it to
    /// `stream`.
    ///
    /// Headers are written in insertion order. If no `Content-Length`
    /// header was set, one is derived from the body length. The exception
    /// is 1xx, 204 and 304 responses, where RFC 9110 §8.6 forbids it
    /// (1xx/204) or only allows the length of the equivalent 200 response
    /// (304). If no `Connection` header was set, `Connection: close` is
    /// added. Output goes through a 4096-byte buffer that is flushed before
    /// returning.
    ///
    /// # Errors
    ///
    /// Returns an error if the status text or any header name or value
    /// contains CR, LF or NUL. That text could otherwise inject extra header
    /// lines or split the response. The error is converted from a
    /// [`std::io::Error`] of kind `InvalidInput`, and nothing is written to
    /// `stream` in that case. Also returns an error if writing to or
    /// flushing `stream` fails.
    #[doc(alias = "serialize")]
    #[doc(alias = "write response")]
    pub fn send<W: Write>(
        &self,
        stream: &mut W,
    ) -> Result<(), ServerError> {
        // Validate before writing so a malformed response never leaves a
        // partial message on the wire.
        self.validate()?;

        let mut w = BufWriter::with_capacity(4096, stream);
        let mut has_content_length = false;
        let mut has_connection = false;
        write!(
            w,
            "HTTP/1.1 {} {}\r\n",
            self.status_code, self.status_text
        )?;
        for (name, value) in &self.headers {
            if name.eq_ignore_ascii_case("content-length") {
                has_content_length = true;
            }
            if name.eq_ignore_ascii_case("connection") {
                has_connection = true;
            }
            write!(w, "{}: {}\r\n", name, value)?;
        }
        if !has_content_length && self.allows_default_content_length() {
            write!(w, "Content-Length: {}\r\n", self.body.len())?;
        }
        if !has_connection {
            w.write_all(b"Connection: close\r\n")?;
        }
        w.write_all(b"\r\n")?;
        w.write_all(&self.body)?;
        w.flush()?;
        Ok(())
    }

    /// Whether `send` may synthesize `Content-Length` from the body length.
    ///
    /// RFC 9110 §8.6 forbids `Content-Length` on 1xx and 204 responses.
    /// On 304 it only permits the length a 200 response would have had,
    /// and `self.body.len()` does not describe that.
    fn allows_default_content_length(&self) -> bool {
        !matches!(self.status_code, 100..=199 | 204 | 304)
    }

    /// Rejects a status text or header field containing CR, LF or NUL,
    /// any of which would corrupt the message framing.
    fn validate(&self) -> std::io::Result<()> {
        fn has_forbidden_byte(text: &str) -> bool {
            text.bytes().any(|b| matches!(b, b'\r' | b'\n' | b'\0'))
        }
        fn invalid(what: &str) -> std::io::Error {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("{what} contains CR, LF or NUL"),
            )
        }

        if has_forbidden_byte(&self.status_text) {
            return Err(invalid("status text"));
        }
        for (name, value) in &self.headers {
            if has_forbidden_byte(name) {
                return Err(invalid("header name"));
            }
            if has_forbidden_byte(value) {
                return Err(invalid("header value"));
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{self, Write};

    /// Writer whose every `write` call fails, for exercising error paths.
    struct FailingStream;

    impl Write for FailingStream {
        fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
            Err(io::Error::other("write error"))
        }
        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Sends `response` into an in-memory buffer and returns the bytes.
    fn send_to_vec(response: &Response) -> Vec<u8> {
        let mut out: Vec<u8> = Vec::new();
        let result = response.send(&mut out);
        assert!(result.is_ok(), "sending into a Vec should succeed");
        out
    }

    #[test]
    fn test_response_new() {
        let body = b"Hello, world!".to_vec();
        let response = Response::new(200, "OK", body.clone());
        assert_eq!(response.status_code, 200);
        assert_eq!(response.status_text, "OK");
        assert!(response.headers.is_empty());
        assert_eq!(response.body, body);
    }

    #[test]
    fn test_response_add_header() {
        let mut response = Response::new(200, "OK", vec![]);
        response.add_header("Content-Type", "text/html");
        assert_eq!(
            response.headers,
            vec![("Content-Type".to_string(), "text/html".to_string())]
        );
    }

    #[test]
    fn test_set_connection_header_replaces_existing() {
        let mut response = Response::new(200, "OK", vec![]);
        response.add_header("connection", "close");
        response.add_header("Content-Type", "text/html");
        response.add_header("CONNECTION", "upgrade");
        response.set_connection_header("keep-alive");
        // Every case variant is removed and the new header goes last.
        assert_eq!(
            response.headers,
            vec![
                ("Content-Type".to_string(), "text/html".to_string()),
                ("Connection".to_string(), "keep-alive".to_string()),
            ]
        );
    }

    #[test]
    fn test_response_send() {
        let mut response =
            Response::new(200, "OK", b"Hello, world!".to_vec());
        response.add_header("Content-Type", "text/plain");
        let expected: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 13\r\nConnection: close\r\n\r\nHello, world!";
        assert_eq!(send_to_vec(&response), expected);
    }

    #[test]
    fn test_response_send_keeps_explicit_framing_headers() {
        let mut response = Response::new(200, "OK", b"Hello".to_vec());
        // Matching is case-insensitive, so no defaults may be appended.
        response.add_header("content-length", "5");
        response.add_header("Connection", "keep-alive");
        let expected: &[u8] = b"HTTP/1.1 200 OK\r\ncontent-length: 5\r\nConnection: keep-alive\r\n\r\nHello";
        assert_eq!(send_to_vec(&response), expected);
    }

    #[test]
    fn test_response_send_error() {
        let mut response =
            Response::new(200, "OK", b"Hello, world!".to_vec());
        response.add_header("Content-Type", "text/plain");
        assert!(response.send(&mut FailingStream).is_err());
    }

    #[test]
    fn test_response_send_propagates_status_line_overflow_error() {
        // Larger than the 4096-byte buffer used by `send`, so the status
        // text is handed straight to the failing inner stream while the
        // status line is being written.
        let huge_status = "X".repeat(8 * 1024);
        let response = Response::new(200, &huge_status, Vec::new());
        let err = response
            .send(&mut FailingStream)
            .expect_err("must fail");
        assert!(err.to_string().contains("write error"));
    }
}