use std::marker::PhantomData;
use std::{cell::Cell, cmp, io, io::Write, mem, ptr, ptr::copy_nonoverlapping, slice};
use crate::http::body::BodySize;
use crate::http::config::DateService;
use crate::http::header::{Value, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use crate::http::helpers;
use crate::http::message::{ConnectionType, RequestHeadType};
use crate::http::response::Response;
use crate::http::{HeaderMap, StatusCode, Version};
use crate::util::{BufMut, BytesMut};
const AVERAGE_HEADER_SIZE: usize = 30;
#[derive(Debug)]
pub(super) struct MessageEncoder<T: MessageType> {
pub(super) length: BodySize,
pub(super) te: Cell<TransferEncoding>,
_t: PhantomData<T>,
}
impl<T: MessageType> Default for MessageEncoder<T> {
fn default() -> Self {
MessageEncoder {
length: BodySize::None,
te: Cell::new(TransferEncoding::empty()),
_t: PhantomData,
}
}
}
impl<T: MessageType> Clone for MessageEncoder<T> {
fn clone(&self) -> Self {
MessageEncoder {
length: self.length,
te: self.te.clone(),
_t: PhantomData,
}
}
}
pub(super) trait MessageType: Sized {
fn status(&self) -> Option<StatusCode>;
fn headers(&self) -> &HeaderMap;
fn extra_headers(&self) -> Option<&HeaderMap>;
fn chunked(&self) -> bool;
fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()>;
fn encode_headers(
&self,
dst: &mut BytesMut,
version: Version,
mut length: BodySize,
ctype: ConnectionType,
timer: &DateService,
) -> io::Result<()> {
let chunked = self.chunked();
let mut skip_len = length != BodySize::Stream;
if let Some(status) = self.status() {
match status {
StatusCode::NO_CONTENT | StatusCode::CONTINUE | StatusCode::PROCESSING => {
length = BodySize::None
}
StatusCode::SWITCHING_PROTOCOLS => {
skip_len = true;
length = BodySize::Stream;
}
_ => (),
}
}
match length {
BodySize::None => dst.extend_from_slice(b"\r\n"),
BodySize::Empty => dst.extend_from_slice(b"\r\ncontent-length: 0\r\n"),
BodySize::Sized(len) => write_content_length(len, dst),
BodySize::Stream => {
if chunked {
skip_len = true;
dst.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n")
} else {
skip_len = false;
dst.extend_from_slice(b"\r\n");
}
}
}
match ctype {
ConnectionType::Upgrade => dst.extend_from_slice(b"connection: upgrade\r\n"),
ConnectionType::KeepAlive if version < Version::HTTP_11 => {
dst.extend_from_slice(b"connection: keep-alive\r\n")
}
ConnectionType::Close if version >= Version::HTTP_11 => {
dst.extend_from_slice(b"connection: close\r\n")
}
_ => (),
}
let empty_headers = HeaderMap::new();
let extra_headers = self.extra_headers().unwrap_or(&empty_headers);
let headers = self
.headers()
.iter_inner()
.filter(|(name, _)| !extra_headers.contains_key(*name))
.chain(extra_headers.iter_inner());
let mut pos = 0;
let mut has_date = false;
let mut remaining = dst.capacity() - dst.len();
let mut buf = dst.chunk_mut().as_mut_ptr() as *mut u8;
for (key, value) in headers {
match *key {
CONNECTION => continue,
TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => continue,
DATE => {
has_date = true;
}
_ => (),
}
let k = key.as_str().as_bytes();
match value {
Value::One(ref val) => {
let v = val.as_ref();
let v_len = v.len();
let k_len = k.len();
let len = k_len + v_len + 4;
unsafe {
if len > remaining {
dst.advance_mut(pos);
pos = 0;
dst.reserve(len * 2);
remaining = dst.capacity() - dst.len();
buf = dst.chunk_mut().as_mut_ptr() as *mut u8;
}
copy_nonoverlapping(k.as_ptr(), buf, k_len);
buf = buf.add(k_len);
copy_nonoverlapping(b": ".as_ptr(), buf, 2);
buf = buf.add(2);
copy_nonoverlapping(v.as_ptr(), buf, v_len);
buf = buf.add(v_len);
copy_nonoverlapping(b"\r\n".as_ptr(), buf, 2);
buf = buf.add(2);
}
pos += len;
remaining -= len;
}
Value::Multi(ref vec) => {
for val in vec {
let v = val.as_ref();
let v_len = v.len();
let k_len = k.len();
let len = k_len + v_len + 4;
unsafe {
if len > remaining {
dst.advance_mut(pos);
pos = 0;
dst.reserve(len * 2);
remaining = dst.capacity() - dst.len();
buf = dst.chunk_mut().as_mut_ptr() as *mut u8;
}
copy_nonoverlapping(k.as_ptr(), buf, k_len);
buf = buf.add(k_len);
copy_nonoverlapping(b": ".as_ptr(), buf, 2);
buf = buf.add(2);
copy_nonoverlapping(v.as_ptr(), buf, v_len);
buf = buf.add(v_len);
copy_nonoverlapping(b"\r\n".as_ptr(), buf, 2);
buf = buf.add(2);
};
pos += len;
remaining -= len;
}
}
}
}
unsafe {
dst.advance_mut(pos);
}
if !has_date {
timer.set_date_header(dst);
} else {
dst.extend_from_slice(b"\r\n");
}
Ok(())
}
}
impl MessageType for Response<()> {
fn status(&self) -> Option<StatusCode> {
Some(self.head().status)
}
fn chunked(&self) -> bool {
self.head().chunked()
}
fn headers(&self) -> &HeaderMap {
&self.head().headers
}
fn extra_headers(&self) -> Option<&HeaderMap> {
None
}
fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()> {
let head = self.head();
let reason = head.reason().as_bytes();
dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
write_status_line(head.version, head.status.as_u16(), dst);
dst.extend_from_slice(reason);
Ok(())
}
}
impl MessageType for RequestHeadType {
fn status(&self) -> Option<StatusCode> {
None
}
fn chunked(&self) -> bool {
self.as_ref().chunked()
}
fn headers(&self) -> &HeaderMap {
self.as_ref().headers()
}
fn extra_headers(&self) -> Option<&HeaderMap> {
self.extra_headers()
}
fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()> {
let head = self.as_ref();
dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);
write!(
helpers::Writer(dst),
"{} {} {}",
head.method,
head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"),
match head.version {
Version::HTTP_09 => "HTTP/0.9",
Version::HTTP_10 => "HTTP/1.0",
Version::HTTP_11 => "HTTP/1.1",
_ =>
return Err(io::Error::new(io::ErrorKind::Other, "unsupported version")),
}
)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}
impl<T: MessageType> MessageEncoder<T> {
pub(super) fn encode_chunk(&self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
let mut te = self.te.get();
let result = te.encode(msg, buf);
self.te.set(te);
result
}
pub(super) fn encode_eof(&self, buf: &mut BytesMut) -> io::Result<()> {
let mut te = self.te.get();
let result = te.encode_eof(buf);
self.te.set(te);
result
}
pub(super) fn encode(
&self,
dst: &mut BytesMut,
message: &mut T,
head: bool,
stream: bool,
version: Version,
length: BodySize,
ctype: ConnectionType,
timer: &DateService,
) -> io::Result<()> {
if !head {
self.te.set(match length {
BodySize::Empty => TransferEncoding::empty(),
BodySize::Sized(len) => TransferEncoding::length(len),
BodySize::Stream => {
if message.chunked() && !stream {
TransferEncoding::chunked()
} else {
TransferEncoding::eof()
}
}
BodySize::None => TransferEncoding::empty(),
});
} else {
self.te.set(TransferEncoding::empty());
}
message.encode_status(dst)?;
message.encode_headers(dst, version, length, ctype, timer)
}
}
#[derive(Debug, Copy, Clone)]
pub(super) struct TransferEncoding {
kind: TransferEncodingKind,
}
#[derive(Debug, PartialEq, Clone, Copy)]
enum TransferEncodingKind {
Chunked(bool),
Length(u64),
Eof,
}
impl TransferEncoding {
#[inline]
pub(super) fn empty() -> TransferEncoding {
TransferEncoding {
kind: TransferEncodingKind::Length(0),
}
}
#[inline]
pub(super) fn eof() -> TransferEncoding {
TransferEncoding {
kind: TransferEncodingKind::Eof,
}
}
#[inline]
pub(super) fn chunked() -> TransferEncoding {
TransferEncoding {
kind: TransferEncodingKind::Chunked(false),
}
}
#[inline]
pub(super) fn length(len: u64) -> TransferEncoding {
TransferEncoding {
kind: TransferEncodingKind::Length(len),
}
}
#[inline]
pub(super) fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
match self.kind {
TransferEncodingKind::Eof => {
let eof = msg.is_empty();
buf.extend_from_slice(msg);
Ok(eof)
}
TransferEncodingKind::Chunked(eof) => {
if eof {
return Ok(true);
}
let result = if msg.is_empty() {
buf.extend_from_slice(b"0\r\n\r\n");
self.kind = TransferEncodingKind::Chunked(true);
true
} else {
writeln!(helpers::Writer(buf), "{:X}\r", msg.len())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
buf.reserve(msg.len() + 2);
buf.extend_from_slice(msg);
buf.extend_from_slice(b"\r\n");
false
};
Ok(result)
}
TransferEncodingKind::Length(mut remaining) => {
if remaining > 0 {
if msg.is_empty() {
return Ok(remaining == 0);
}
let len = cmp::min(remaining, msg.len() as u64);
buf.extend_from_slice(&msg[..len as usize]);
remaining -= len;
self.kind = TransferEncodingKind::Length(remaining);
Ok(remaining == 0)
} else {
Ok(true)
}
}
}
}
#[inline]
pub(super) fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
match self.kind {
TransferEncodingKind::Eof => Ok(()),
TransferEncodingKind::Length(rem) => {
if rem != 0 {
Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))
} else {
Ok(())
}
}
TransferEncodingKind::Chunked(eof) => {
if !eof {
buf.extend_from_slice(b"0\r\n\r\n");
self.kind = TransferEncodingKind::Chunked(true);
}
Ok(())
}
}
}
}
const DEC_DIGITS_LUT: &[u8] = b"0001020304050607080910111213141516171819\
2021222324252627282930313233343536373839\
4041424344454647484950515253545556575859\
6061626364656667686970717273747576777879\
8081828384858687888990919293949596979899";
const STATUS_LINE_BUF_SIZE: usize = 13;
fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesMut) {
let mut buf: [u8; STATUS_LINE_BUF_SIZE] = match version {
Version::HTTP_2 => *b"HTTP/2 ",
Version::HTTP_10 => *b"HTTP/1.0 ",
Version::HTTP_09 => *b"HTTP/0.9 ",
_ => *b"HTTP/1.1 ",
};
let mut curr: isize = 12;
let buf_ptr = buf.as_mut_ptr();
let lut_ptr = DEC_DIGITS_LUT.as_ptr();
let four = n > 999;
let d1 = (n % 100) << 1;
n /= 100;
curr -= 2;
unsafe {
ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2);
if n < 10 {
curr -= 1;
*buf_ptr.offset(curr) = (n as u8) + b'0';
} else {
let d1 = n << 1;
curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2);
}
}
bytes.extend_from_slice(&buf);
if four {
bytes.put_u8(b' ');
}
}
fn write_content_length(mut n: u64, bytes: &mut BytesMut) {
if n < 10 {
let mut buf: [u8; 21] = [
b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e', b'n',
b'g', b't', b'h', b':', b' ', b'0', b'\r', b'\n',
];
buf[18] = (n as u8) + b'0';
bytes.extend_from_slice(&buf);
} else if n < 100 {
let mut buf: [u8; 22] = [
b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e', b'n',
b'g', b't', b'h', b':', b' ', b'0', b'0', b'\r', b'\n',
];
let d1 = n << 1;
unsafe {
ptr::copy_nonoverlapping(
DEC_DIGITS_LUT.as_ptr().add(d1 as usize),
buf.as_mut_ptr().offset(18),
2,
);
}
bytes.extend_from_slice(&buf);
} else if n < 1000 {
let mut buf: [u8; 23] = [
b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e', b'n',
b'g', b't', b'h', b':', b' ', b'0', b'0', b'0', b'\r', b'\n',
];
let d1 = (n % 100) << 1;
n /= 100;
unsafe {
ptr::copy_nonoverlapping(
DEC_DIGITS_LUT.as_ptr().add(d1 as usize),
buf.as_mut_ptr().offset(19),
2,
)
};
buf[18] = (n as u8) + b'0';
bytes.extend_from_slice(&buf);
} else {
bytes.extend_from_slice(b"\r\ncontent-length: ");
unsafe { convert_usize(n, bytes) };
}
}
unsafe fn convert_usize(mut n: u64, bytes: &mut BytesMut) {
let mut curr: isize = 39;
#[allow(invalid_value, clippy::uninit_assumed_init)]
let mut buf: [u8; 41] = mem::MaybeUninit::uninit().assume_init();
buf[39] = b'\r';
buf[40] = b'\n';
let buf_ptr = buf.as_mut_ptr();
let lut_ptr = DEC_DIGITS_LUT.as_ptr();
while n >= 10_000 {
let rem = (n % 10_000) as isize;
n /= 10_000;
let d1 = (rem / 100) << 1;
let d2 = (rem % 100) << 1;
curr -= 4;
ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d2), buf_ptr.offset(curr + 2), 2);
}
let mut n = n as isize;
if n >= 100 {
let d1 = (n % 100) << 1;
n /= 100;
curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2);
}
if n < 10 {
curr -= 1;
*buf_ptr.offset(curr) = (n as u8) + b'0';
} else {
let d1 = n << 1;
curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2);
}
bytes.extend_from_slice(slice::from_raw_parts(
buf_ptr.offset(curr),
41 - curr as usize,
));
}
#[cfg(test)]
mod tests {
use std::rc::Rc;
use super::*;
use crate::http::header::{HeaderValue, AUTHORIZATION};
use crate::http::RequestHead;
use crate::util::Bytes;
#[test]
fn test_chunked_te() {
let mut bytes = BytesMut::new();
let mut enc = TransferEncoding::chunked();
{
assert!(!enc.encode(b"test", &mut bytes).ok().unwrap());
assert!(enc.encode(b"", &mut bytes).ok().unwrap());
}
assert_eq!(bytes.split(), Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n"));
}
#[test]
fn test_extra_headers() {
let mut bytes = BytesMut::with_capacity(2048);
let mut head = RequestHead::default();
head.headers.insert(
AUTHORIZATION,
HeaderValue::from_static("some authorization"),
);
let mut extra_headers = HeaderMap::new();
extra_headers.insert(
AUTHORIZATION,
HeaderValue::from_static("another authorization"),
);
extra_headers.insert(DATE, HeaderValue::from_static("date"));
let head = RequestHeadType::Rc(Rc::new(head), Some(extra_headers));
let _ = head.encode_headers(
&mut bytes,
Version::HTTP_11,
BodySize::Empty,
ConnectionType::Close,
&DateService::default(),
);
let data = String::from_utf8(Vec::from(bytes.split().as_ref())).unwrap();
assert!(data.contains("content-length: 0\r\n"));
assert!(data.contains("connection: close\r\n"));
assert!(data.contains("authorization: another authorization\r\n"));
assert!(data.contains("date: date\r\n"));
}
#[test]
fn test_write_content_length() {
let mut bytes = BytesMut::new();
bytes.reserve(50);
write_content_length(0, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 0\r\n"[..]);
bytes.reserve(50);
write_content_length(9, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 9\r\n"[..]);
bytes.reserve(50);
write_content_length(10, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 10\r\n"[..]);
bytes.reserve(50);
write_content_length(99, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 99\r\n"[..]);
bytes.reserve(50);
write_content_length(100, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 100\r\n"[..]);
bytes.reserve(50);
write_content_length(101, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 101\r\n"[..]);
bytes.reserve(50);
write_content_length(998, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 998\r\n"[..]);
bytes.reserve(50);
write_content_length(1000, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 1000\r\n"[..]);
bytes.reserve(50);
write_content_length(1001, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 1001\r\n"[..]);
bytes.reserve(50);
write_content_length(5909, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 5909\r\n"[..]);
write_content_length(25999, &mut bytes);
assert_eq!(bytes.split(), b"\r\ncontent-length: 25999\r\n"[..]);
}
}