eternal 0.3.2

The Eternal Framework
Documentation
use std::io::Write;
use std::marker::PhantomData;
use std::ptr::copy_nonoverlapping;
use std::slice::from_raw_parts_mut;
use std::{cmp, io};

use bytes::{buf::BufMutExt, BufMut, BytesMut};

use crate::http::body::BodySize;
use crate::http::config::ServiceConfig;
use crate::http::header::map;
use crate::http::helpers;
use crate::http::header::{CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use crate::http::{HeaderMap, StatusCode, Version};
use crate::http::message::{ConnectionType, RequestHeadType};
use crate::http::response::Response;

/// Rough per-header byte estimate, multiplied by the header count when
/// `encode_status` pre-reserves space in the output buffer.
const AVERAGE_HEADER_SIZE: usize = 30;

/// Encoder that serializes an HTTP/1 message head of type `T` and then its body,
/// using the [`TransferEncoding`] selected by [`MessageEncoder::encode`].
#[derive(Debug)]
pub(crate) struct MessageEncoder<T: MessageType> {
    /// Body size bookkeeping; starts out as `BodySize::None`.
    /// NOTE(review): not read or written elsewhere in this section — confirm its users.
    pub length: BodySize,
    /// Active body encoder, replaced on every call to `encode`.
    pub te: TransferEncoding,
    _t: PhantomData<T>,
}

impl<T: MessageType> Default for MessageEncoder<T> {
    /// An encoder with no body: `BodySize::None` and a zero-length transfer encoding.
    fn default() -> Self {
        let te = TransferEncoding::empty();
        Self {
            length: BodySize::None,
            te,
            _t: PhantomData,
        }
    }
}

pub(crate) trait MessageType: Sized {
    fn status(&self) -> Option<StatusCode>;

    fn headers(&self) -> &HeaderMap;

    fn extra_headers(&self) -> Option<&HeaderMap>;

    fn camel_case(&self) -> bool {
        false
    }

    fn chunked(&self) -> bool;

    fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>;

    fn encode_headers(
        &mut self,
        dst: &mut BytesMut,
        version: Version,
        mut length: BodySize,
        ctype: ConnectionType,
        config: &ServiceConfig,
    ) -> io::Result<()> {
        let chunked = self.chunked();
        let mut skip_len = length != BodySize::Stream;
        let camel_case = self.camel_case();

        // Content length
        if let Some(status) = self.status() {
            match status {
                StatusCode::NO_CONTENT
                | StatusCode::CONTINUE
                | StatusCode::PROCESSING => length = BodySize::None,
                StatusCode::SWITCHING_PROTOCOLS => {
                    skip_len = true;
                    length = BodySize::Stream;
                }
                _ => (),
            }
        }
        match length {
            BodySize::Stream => {
                if chunked {
                    if camel_case {
                        dst.put_slice(b"\r\nTransfer-Encoding: chunked\r\n")
                    } else {
                        dst.put_slice(b"\r\ntransfer-encoding: chunked\r\n")
                    }
                } else {
                    skip_len = false;
                    dst.put_slice(b"\r\n");
                }
            }
            BodySize::Empty => {
                if camel_case {
                    dst.put_slice(b"\r\nContent-Length: 0\r\n");
                } else {
                    dst.put_slice(b"\r\ncontent-length: 0\r\n");
                }
            }
            BodySize::Sized(len) => helpers::write_content_length(len, dst),
            BodySize::Sized64(len) => {
                if camel_case {
                    dst.put_slice(b"\r\nContent-Length: ");
                } else {
                    dst.put_slice(b"\r\ncontent-length: ");
                }
                #[allow(clippy::write_with_newline)]
                write!(dst.writer(), "{}\r\n", len)?;
            }
            BodySize::None => dst.put_slice(b"\r\n"),
        }

        // Connection
        match ctype {
            ConnectionType::Upgrade => dst.put_slice(b"connection: upgrade\r\n"),
            ConnectionType::KeepAlive if version < Version::HTTP_11 => {
                if camel_case {
                    dst.put_slice(b"Connection: keep-alive\r\n")
                } else {
                    dst.put_slice(b"connection: keep-alive\r\n")
                }
            }
            ConnectionType::Close if version >= Version::HTTP_11 => {
                if camel_case {
                    dst.put_slice(b"Connection: close\r\n")
                } else {
                    dst.put_slice(b"connection: close\r\n")
                }
            }
            _ => (),
        }

        // merging headers from head and extra headers. HeaderMap::new() does not allocate.
        let empty_headers = HeaderMap::new();
        let extra_headers = self.extra_headers().unwrap_or(&empty_headers);
        let headers = self
            .headers()
            .inner
            .iter()
            .filter(|(name, _)| !extra_headers.contains_key(*name))
            .chain(extra_headers.inner.iter());

        // write headers
        let mut pos = 0;
        let mut has_date = false;
        let mut remaining = dst.capacity() - dst.len();
        let mut buf = dst.bytes_mut().as_mut_ptr() as *mut u8;
        for (key, value) in headers {
            match *key {
                CONNECTION => continue,
                TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => continue,
                DATE => {
                    has_date = true;
                }
                _ => (),
            }
            let k = key.as_str().as_bytes();
            match value {
                map::Value::One(ref val) => {
                    let v = val.as_ref();
                    let v_len = v.len();
                    let k_len = k.len();
                    let len = k_len + v_len + 4;
                    if len > remaining {
                        unsafe {
                            dst.advance_mut(pos);
                        }
                        pos = 0;
                        dst.reserve(len * 2);
                        remaining = dst.capacity() - dst.len();
                        buf = dst.bytes_mut().as_mut_ptr() as *mut u8;
                    }
                    // use upper Camel-Case
                    unsafe {
                        if camel_case {
                            write_camel_case(k, from_raw_parts_mut(buf, k_len))
                        } else {
                            write_data(k, buf, k_len)
                        }
                        buf = buf.add(k_len);
                        write_data(b": ", buf, 2);
                        buf = buf.add(2);
                        write_data(v, buf, v_len);
                        buf = buf.add(v_len);
                        write_data(b"\r\n", buf, 2);
                        buf = buf.add(2);
                        pos += len;
                        remaining -= len;
                    }
                }
                map::Value::Multi(ref vec) => {
                    for val in vec {
                        let v = val.as_ref();
                        let v_len = v.len();
                        let k_len = k.len();
                        let len = k_len + v_len + 4;
                        if len > remaining {
                            unsafe {
                                dst.advance_mut(pos);
                            }
                            pos = 0;
                            dst.reserve(len * 2);
                            remaining = dst.capacity() - dst.len();
                            buf = dst.bytes_mut().as_mut_ptr() as *mut u8;
                        }
                        // use upper Camel-Case
                        unsafe {
                            if camel_case {
                                write_camel_case(k, from_raw_parts_mut(buf, k_len));
                            } else {
                                write_data(k, buf, k_len);
                            }
                            buf = buf.add(k_len);
                            write_data(b": ", buf, 2);
                            buf = buf.add(2);
                            write_data(v, buf, v_len);
                            buf = buf.add(v_len);
                            write_data(b"\r\n", buf, 2);
                            buf = buf.add(2);
                        };
                        pos += len;
                        remaining -= len;
                    }
                }
            }
        }
        unsafe {
            dst.advance_mut(pos);
        }

        // optimized date header, set_date writes \r\n
        if !has_date {
            config.set_date(dst);
        } else {
            // msg eof
            dst.extend_from_slice(b"\r\n");
        }

        Ok(())
    }
}

/// Server-side responses: always carry a status; never use Camel-Case headers
/// or extra headers.
impl MessageType for Response<()> {
    fn status(&self) -> Option<StatusCode> {
        let status = self.head().status;
        Some(status)
    }

    fn chunked(&self) -> bool {
        let head = self.head();
        head.chunked()
    }

    fn headers(&self) -> &HeaderMap {
        let head = self.head();
        &head.headers
    }

    fn extra_headers(&self) -> Option<&HeaderMap> {
        None
    }

    /// Writes the status line (version, code, then reason phrase) after
    /// reserving an estimate of the full head size.
    fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
        let head = self.head();
        let reason = head.reason().as_bytes();
        let estimated = 256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len();
        dst.reserve(estimated);

        let code = head.status.as_u16();
        helpers::write_status_line(head.version, code, dst);
        dst.put_slice(reason);
        Ok(())
    }
}

/// Client-side request heads: no status code; Camel-Case and extra headers are
/// taken from the wrapped request head.
impl MessageType for RequestHeadType {
    fn status(&self) -> Option<StatusCode> {
        None
    }

    fn chunked(&self) -> bool {
        let head = self.as_ref();
        head.chunked()
    }

    fn camel_case(&self) -> bool {
        let head = self.as_ref();
        head.camel_case_headers()
    }

    fn headers(&self) -> &HeaderMap {
        let head = self.as_ref();
        head.headers()
    }

    fn extra_headers(&self) -> Option<&HeaderMap> {
        // NOTE(review): this presumably resolves to an inherent
        // `RequestHeadType::extra_headers` (inherent methods win over trait methods);
        // if no such inherent method existed this would recurse forever — confirm.
        self.extra_headers()
    }

    /// Writes the request line `METHOD path HTTP/x.y`, using `/` when the URI has
    /// no path.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::Other` ("unsupported version") for HTTP versions not
    /// listed below.
    fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
        let head = self.as_ref();
        dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);

        let version = match head.version {
            Version::HTTP_09 => "HTTP/0.9",
            Version::HTTP_10 => "HTTP/1.0",
            Version::HTTP_11 => "HTTP/1.1",
            Version::HTTP_2 => "HTTP/2.0",
            Version::HTTP_3 => "HTTP/3.0",
            _ => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "unsupported version",
                ))
            }
        };
        let path = head
            .uri
            .path_and_query()
            .map(|pq| pq.as_str())
            .unwrap_or("/");

        write!(Writer(dst), "{} {} {}", head.method, path, version)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
    }
}

impl<T: MessageType> MessageEncoder<T> {
    /// Encodes one body chunk with the active transfer encoding.
    ///
    /// Returns the encoder's EOF state after the chunk.
    pub fn encode_chunk(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
        self.te.encode(msg, buf)
    }

    /// Finishes the body with the active transfer encoding.
    pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
        self.te.encode_eof(buf)
    }

    /// Selects the body transfer encoding, then writes the start line and headers.
    ///
    /// For `head` responses the body encoder is always empty. A streamed body is
    /// chunked only when the message asks for chunking and `stream` is false;
    /// otherwise it is written until EOF.
    pub fn encode(
        &mut self,
        dst: &mut BytesMut,
        message: &mut T,
        head: bool,
        stream: bool,
        version: Version,
        length: BodySize,
        ctype: ConnectionType,
        config: &ServiceConfig,
    ) -> io::Result<()> {
        self.te = if head {
            TransferEncoding::empty()
        } else {
            match length {
                BodySize::Sized(len) => TransferEncoding::length(len as u64),
                BodySize::Sized64(len) => TransferEncoding::length(len),
                BodySize::Stream if message.chunked() && !stream => {
                    TransferEncoding::chunked()
                }
                BodySize::Stream => TransferEncoding::eof(),
                BodySize::Empty | BodySize::None => TransferEncoding::empty(),
            }
        };

        message.encode_status(dst)?;
        message.encode_headers(dst, version, length, ctype, config)
    }
}

/// Body encoder for the HTTP/1 transfer framings (chunked, fixed length, or
/// until connection EOF).
#[derive(Debug)]
pub(crate) struct TransferEncoding {
    kind: TransferEncodingKind,
}

#[derive(Debug, PartialEq, Clone)]
enum TransferEncodingKind {
    /// Chunked framing; the flag records whether the terminating `0\r\n\r\n`
    /// chunk has already been written.
    Chunked(bool),
    /// Fixed-length framing carrying the number of bytes still allowed.
    ///
    /// Input beyond that count is silently truncated.
    Length(u64),
    /// Unframed body: bytes are passed through until the application stops.
    Eof,
}

impl TransferEncoding {
    /// An encoder that accepts no body bytes (a zero-length body).
    #[inline]
    pub fn empty() -> TransferEncoding {
        Self::length(0)
    }

    /// An encoder that passes bytes through unframed.
    #[inline]
    pub fn eof() -> TransferEncoding {
        let kind = TransferEncodingKind::Eof;
        TransferEncoding { kind }
    }

    /// A chunked encoder that has not yet written its final chunk.
    #[inline]
    pub fn chunked() -> TransferEncoding {
        let kind = TransferEncodingKind::Chunked(false);
        TransferEncoding { kind }
    }

    /// A fixed-length encoder accepting at most `len` bytes.
    #[inline]
    pub fn length(len: u64) -> TransferEncoding {
        let kind = TransferEncodingKind::Length(len);
        TransferEncoding { kind }
    }

    /// Encodes `msg` into `buf` and returns whether the body is now complete.
    ///
    /// * Eof: `msg` is copied as is; an empty `msg` reports completion.
    /// * Chunked: `msg` becomes one hex-sized chunk; an empty `msg` writes the
    ///   terminating chunk. Once terminated, further calls write nothing.
    /// * Length: copies at most the remaining byte budget from `msg`.
    #[inline]
    pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
        match &mut self.kind {
            TransferEncodingKind::Eof => {
                buf.extend_from_slice(msg);
                Ok(msg.is_empty())
            }
            TransferEncodingKind::Chunked(finished) => {
                if !*finished {
                    if msg.is_empty() {
                        *finished = true;
                        buf.extend_from_slice(b"0\r\n\r\n");
                    } else {
                        writeln!(Writer(buf), "{:X}\r", msg.len())
                            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

                        buf.reserve(msg.len() + 2);
                        buf.extend_from_slice(msg);
                        buf.extend_from_slice(b"\r\n");
                    }
                }
                Ok(*finished)
            }
            TransferEncodingKind::Length(left) => {
                if *left == 0 {
                    return Ok(true);
                }
                let take = cmp::min(*left, msg.len() as u64);
                buf.extend_from_slice(&msg[..take as usize]);
                *left -= take;
                Ok(*left == 0)
            }
        }
    }

    /// Finishes the body.
    ///
    /// Writes the terminating chunk for chunked bodies that are not yet finished.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::UnexpectedEof` if a fixed-length body still expects bytes.
    #[inline]
    pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
        match &mut self.kind {
            TransferEncodingKind::Eof => Ok(()),
            TransferEncodingKind::Length(0) => Ok(()),
            TransferEncodingKind::Length(_) => {
                Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))
            }
            TransferEncodingKind::Chunked(finished) => {
                if !*finished {
                    *finished = true;
                    buf.extend_from_slice(b"0\r\n\r\n");
                }
                Ok(())
            }
        }
    }
}

/// `io::Write` adapter that appends everything written to the wrapped `BytesMut`,
/// so `write!`/`writeln!` can format directly into the buffer.
///
/// Writes always accept the whole input and never fail; `flush` does nothing.
struct Writer<'a>(pub &'a mut BytesMut);

impl io::Write for Writer<'_> {
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        let Writer(inner) = self;
        inner.extend_from_slice(data);
        Ok(data.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

/// Copies the first `len` bytes of `value` to the raw destination `buf`.
///
/// # Safety
///
/// The caller must guarantee that:
/// * `buf` is valid for writes of `len` bytes;
/// * the destination does not overlap `value`;
/// * `len <= value.len()`. Violating this reads past the source slice; it is
///   checked with `debug_assert!` in debug builds.
unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) {
    debug_assert!(
        len <= value.len(),
        "write_data: len ({}) exceeds source length ({})",
        len,
        value.len()
    );
    copy_nonoverlapping(value.as_ptr(), buf, len);
}

/// Writes the header name `value` into the front of `buffer` in Camel-Case.
///
/// The first byte and every byte following a `-` are upper-cased (for example,
/// `content-type` becomes `Content-Type`). Every input byte is copied. Bytes
/// that have no upper-case form are written unchanged rather than dropped; these
/// include digits and a second `-`. Dropping them would shift the rest of the
/// name and leave the final byte of the destination unwritten. Bytes of
/// `buffer` past `value.len()` are not touched.
///
/// # Panics
///
/// Panics if `buffer` is shorter than `value`.
fn write_camel_case(value: &[u8], buffer: &mut [u8]) {
    let out = &mut buffer[..value.len()];
    out.copy_from_slice(value);

    let mut word_start = true;
    for byte in out.iter_mut() {
        if word_start {
            // Only changes `a..=z`; every other byte is left as is.
            byte.make_ascii_uppercase();
        }
        word_start = *byte == b'-';
    }
}

#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use bytes::Bytes;
    use http::header::AUTHORIZATION;

    use super::*;
    use crate::http::header::{HeaderValue, CONTENT_TYPE};
    use crate::http::RequestHead;

    /// Encodes the headers of `head` as HTTP/1.1 with a default service config.
    /// The result is ignored; callers only inspect what was written.
    fn encode_http11(
        head: &mut RequestHeadType,
        bytes: &mut BytesMut,
        length: BodySize,
        ctype: ConnectionType,
    ) {
        let _ = head.encode_headers(
            bytes,
            Version::HTTP_11,
            length,
            ctype,
            &ServiceConfig::default(),
        );
    }

    /// Splits off everything written to `bytes` so far and decodes it as UTF-8.
    fn take_text(bytes: &mut BytesMut) -> String {
        String::from_utf8(bytes.split().to_vec()).unwrap()
    }

    /// Asserts that `text` contains each of `needles`.
    fn assert_contains_all(text: &str, needles: &[&str]) {
        for needle in needles {
            assert!(text.contains(*needle), "missing {:?} in {:?}", needle, text);
        }
    }

    #[test]
    fn test_chunked_te() {
        let mut enc = TransferEncoding::chunked();
        let mut bytes = BytesMut::new();

        assert!(!enc.encode(b"test", &mut bytes).ok().unwrap());
        assert!(enc.encode(b"", &mut bytes).ok().unwrap());

        let expected = Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n");
        assert_eq!(bytes.split().freeze(), expected);
    }

    #[test]
    fn test_camel_case() {
        let mut bytes = BytesMut::with_capacity(2048);

        let mut req = RequestHead::default();
        req.set_camel_case_headers(true);
        req.headers.insert(DATE, HeaderValue::from_static("date"));
        req.headers
            .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text"));
        let mut head = RequestHeadType::Owned(req);

        encode_http11(&mut head, &mut bytes, BodySize::Empty, ConnectionType::Close);
        assert_contains_all(
            &take_text(&mut bytes),
            &[
                "Content-Length: 0\r\n",
                "Connection: close\r\n",
                "Content-Type: plain/text\r\n",
                "Date: date\r\n",
            ],
        );

        encode_http11(
            &mut head,
            &mut bytes,
            BodySize::Stream,
            ConnectionType::KeepAlive,
        );
        assert_contains_all(
            &take_text(&mut bytes),
            &[
                "Transfer-Encoding: chunked\r\n",
                "Content-Type: plain/text\r\n",
                "Date: date\r\n",
            ],
        );

        encode_http11(
            &mut head,
            &mut bytes,
            BodySize::Sized64(100),
            ConnectionType::KeepAlive,
        );
        assert_contains_all(
            &take_text(&mut bytes),
            &[
                "Content-Length: 100\r\n",
                "Content-Type: plain/text\r\n",
                "Date: date\r\n",
            ],
        );

        let mut req = RequestHead::default();
        req.set_camel_case_headers(false);
        req.headers.insert(DATE, HeaderValue::from_static("date"));
        req.headers
            .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text"));
        req.headers
            .append(CONTENT_TYPE, HeaderValue::from_static("xml"));
        let mut head = RequestHeadType::Owned(req);

        encode_http11(
            &mut head,
            &mut bytes,
            BodySize::Stream,
            ConnectionType::KeepAlive,
        );
        assert_contains_all(
            &take_text(&mut bytes),
            &[
                "transfer-encoding: chunked\r\n",
                "content-type: xml\r\n",
                "content-type: plain/text\r\n",
                "date: date\r\n",
            ],
        );
    }

    #[test]
    fn test_extra_headers() {
        let mut bytes = BytesMut::with_capacity(2048);

        let mut req = RequestHead::default();
        req.headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("some authorization"),
        );

        let mut extra = HeaderMap::new();
        extra.insert(
            AUTHORIZATION,
            HeaderValue::from_static("another authorization"),
        );
        extra.insert(DATE, HeaderValue::from_static("date"));

        let mut head = RequestHeadType::Rc(Rc::new(req), Some(extra));

        encode_http11(&mut head, &mut bytes, BodySize::Empty, ConnectionType::Close);
        assert_contains_all(
            &take_text(&mut bytes),
            &[
                "content-length: 0\r\n",
                "connection: close\r\n",
                "authorization: another authorization\r\n",
                "date: date\r\n",
            ],
        );
    }
}