use std::time::{SystemTime, UNIX_EPOCH};
use super::{Headers, Method, Request, Response, StatusCode, Version};
use crate::error::{Error, Result};
/// Size limits enforced while parsing requests.
#[derive(Debug, Clone, Copy)]
pub struct Limits {
    /// Largest accepted request head (request-line plus header fields), in bytes;
    /// anything bigger is answered with 431.
    pub max_header_bytes: usize,
    /// Largest accepted request body, in bytes (measured after de-chunking for
    /// chunked bodies); anything bigger is answered with 413.
    pub max_body_bytes: usize,
}

impl Default for Limits {
    /// 64 KiB of head and 16 MiB of body.
    fn default() -> Self {
        const KIB: usize = 1024;
        Self {
            max_header_bytes: 64 * KIB,
            max_body_bytes: 16 * KIB * KIB,
        }
    }
}
/// Facts about a parsed request that shape the response eventually written for it.
#[derive(Debug, Clone, Copy)]
struct Pending {
    /// Protocol version written into the response status line.
    version: Version,
    /// Whether the client negotiated a persistent connection.
    keep_alive: bool,
    /// Whether the request was `HEAD`, in which case the response body is not sent.
    is_head: bool,
}
/// An I/O-free HTTP/1.x server connection.
///
/// Received bytes go in through `feed`, parsed requests come out of `poll_request`,
/// and serialized responses are collected with `take_out`.
#[derive(Debug)]
pub struct H1Conn {
    /// Received bytes not yet consumed by a parsed request.
    inbuf: Vec<u8>,
    /// Serialized bytes waiting to be written to the peer.
    outbuf: Vec<u8>,
    /// Size limits enforced while parsing.
    limits: Limits,
    /// Response metadata for the request currently awaiting `respond`.
    pending: Option<Pending>,
    /// Set once the connection must be closed after the queued output is written.
    closed: bool,
    /// Whether a `100 Continue` was already queued for the request being read.
    interim_sent: bool,
    /// `Server` header value added to responses that lack one; `None` adds nothing.
    server_name: Option<String>,
}
impl Default for H1Conn {
    /// A connection using the default [`Limits`].
    fn default() -> Self {
        Self::new(Limits::default())
    }
}
impl H1Conn {
pub fn new(limits: Limits) -> H1Conn {
H1Conn {
inbuf: Vec::new(),
outbuf: Vec::new(),
limits,
pending: None,
closed: false,
interim_sent: false,
server_name: Some(concat!("httpsd/", env!("CARGO_PKG_VERSION")).to_owned()),
}
}
pub fn set_server_name(&mut self, name: Option<String>) {
self.server_name = name;
}
pub fn feed(&mut self, data: &[u8]) {
self.inbuf.extend_from_slice(data);
}
pub fn take_out(&mut self) -> Vec<u8> {
std::mem::take(&mut self.outbuf)
}
pub fn has_output(&self) -> bool {
!self.outbuf.is_empty()
}
pub fn wants_close(&self) -> bool {
self.closed
}
pub fn awaiting_response(&self) -> bool {
self.pending.is_some()
}
pub fn poll_request(&mut self) -> Result<Option<Request>> {
if self.closed || self.pending.is_some() {
return Ok(None);
}
let Some(head_end) = find_subslice(&self.inbuf, b"\r\n\r\n") else {
if self.inbuf.len() > self.limits.max_header_bytes {
return Err(self.fail(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, "headers"));
}
return Ok(None);
};
let header_block_len = head_end; if header_block_len > self.limits.max_header_bytes {
return Err(self.fail(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, "headers"));
}
let body_start = head_end + 4;
let head = &self.inbuf[..header_block_len];
let (method, target, version, headers) = match parse_head(head) {
Ok(parts) => parts,
Err(e) => {
let status = match &e {
Error::BadRequest(_) => StatusCode::BAD_REQUEST,
_ => StatusCode::BAD_REQUEST,
};
return Err(self.fail(status, "request line/headers"));
}
};
if version.is_none() {
return Err(self.fail(StatusCode::HTTP_VERSION_NOT_SUPPORTED, "version"));
}
let version = version.unwrap();
let framing = match body_framing(&headers) {
Ok(f) => f,
Err(()) => return Err(self.fail(StatusCode::BAD_REQUEST, "body framing")),
};
let body: Vec<u8>;
let consumed_total: usize;
match framing {
BodyFraming::None => {
body = Vec::new();
consumed_total = body_start;
}
BodyFraming::Length(len) => {
if len > self.limits.max_body_bytes {
return Err(self.fail(StatusCode::PAYLOAD_TOO_LARGE, "body"));
}
if self.inbuf.len() < body_start + len {
self.maybe_send_continue(&headers);
return Ok(None);
}
body = self.inbuf[body_start..body_start + len].to_vec();
consumed_total = body_start + len;
}
BodyFraming::Chunked => {
match decode_chunked(&self.inbuf[body_start..], self.limits.max_body_bytes) {
Ok(Some((decoded, used))) => {
body = decoded;
consumed_total = body_start + used;
}
Ok(None) => {
self.maybe_send_continue(&headers);
return Ok(None);
}
Err(status) => return Err(self.fail(status, "chunked body")),
}
}
}
self.inbuf.drain(..consumed_total);
self.interim_sent = false;
let keep_alive = negotiate_keep_alive(version, &headers);
let is_head = method.is_head();
self.pending = Some(Pending {
version,
keep_alive,
is_head,
});
Ok(Some(Request::new(method, target, version, headers, body)))
}
pub fn respond(&mut self, resp: Response) {
let meta = self
.pending
.take()
.expect("respond() called with no request in flight");
self.serialize(meta, resp);
}
fn fail(&mut self, status: StatusCode, what: &'static str) -> Error {
let meta = Pending {
version: Version::Http11,
keep_alive: false,
is_head: false,
};
let resp = Response::status(status);
self.pending = None;
self.serialize(meta, resp);
self.closed = true;
match status.code() {
413 | 431 => Error::TooLarge(what),
_ => Error::BadRequest(what),
}
}
fn maybe_send_continue(&mut self, headers: &Headers) {
if !self.interim_sent && headers.contains_token("expect", "100-continue") {
self.outbuf
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
self.interim_sent = true;
}
}
fn serialize(&mut self, meta: Pending, resp: Response) {
let (status, mut headers, body) = resp.into_parts();
let bodyless = status.is_bodyless();
let omit_body = bodyless || meta.is_head;
if !bodyless {
headers.set("Content-Length", body.len().to_string());
} else {
headers.remove("Content-Length");
}
let keep_alive = meta.keep_alive && !self.closed;
headers.set(
"Connection",
if keep_alive { "keep-alive" } else { "close" },
);
if let Some(server) = &self.server_name {
headers.set_if_absent("Server", server.clone());
}
headers.set_if_absent("Date", http_date(now_secs()));
let line = format!(
"{} {} {}\r\n",
meta.version.as_str(),
status.code(),
status.reason()
);
self.outbuf.extend_from_slice(line.as_bytes());
for (name, value) in headers.iter() {
self.outbuf.extend_from_slice(name.as_bytes());
self.outbuf.extend_from_slice(b": ");
self.outbuf.extend_from_slice(value.as_bytes());
self.outbuf.extend_from_slice(b"\r\n");
}
self.outbuf.extend_from_slice(b"\r\n");
if !omit_body {
self.outbuf.extend_from_slice(&body);
}
if !keep_alive {
self.closed = true;
}
}
}
/// How a request body is delimited, as decided by `body_framing`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BodyFraming {
    /// No body: no framing headers, or `Content-Length: 0`.
    None,
    /// Exactly this many bytes follow the head (`Content-Length`).
    Length(usize),
    /// `Transfer-Encoding: chunked`.
    Chunked,
}
/// Determines how the request body is delimited (RFC 9112 §6.3).
///
/// Returns `Err(())` whenever the framing is ambiguous or unsupported; the caller answers
/// that with 400 and closes the connection. This happens when:
/// * both `Transfer-Encoding` and `Content-Length` are present;
/// * `Transfer-Encoding` lists anything other than exactly one `chunked` coding. No other
///   transfer coding is decoded here, and `chunked` must be the final coding anyway;
/// * a `Content-Length` value is not a plain run of decimal digits, does not fit in
///   `usize`, or disagrees with another `Content-Length` value.
fn body_framing(headers: &Headers) -> std::result::Result<BodyFraming, ()> {
    let has_te = headers.contains("transfer-encoding");
    let has_cl = headers.contains("content-length");
    if has_te && has_cl {
        // Conflicting framing is the classic request-smuggling vector; refuse it outright.
        return Err(());
    }
    if has_te {
        let mut codings = 0usize;
        let mut all_chunked = true;
        for value in headers.get_all("transfer-encoding") {
            // Empty list elements (e.g. "chunked, ,") are permitted by RFC 9110 §5.6.1.
            for coding in value.split(',').map(str::trim).filter(|c| !c.is_empty()) {
                codings += 1;
                all_chunked &= coding.eq_ignore_ascii_case("chunked");
            }
        }
        return if codings == 1 && all_chunked {
            Ok(BodyFraming::Chunked)
        } else {
            Err(())
        };
    }
    let mut length: Option<usize> = None;
    for value in headers.get_all("content-length") {
        let digits = value.trim();
        // Content-Length = 1*DIGIT; `usize::from_str` alone would also accept "+5".
        if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
            return Err(());
        }
        let parsed: usize = digits.parse().map_err(|_| ())?;
        match length {
            Some(prev) if prev != parsed => return Err(()),
            _ => length = Some(parsed),
        }
    }
    match length {
        Some(0) | None => Ok(BodyFraming::None),
        Some(n) => Ok(BodyFraming::Length(n)),
    }
}
/// Decides whether the connection stays open after this request.
///
/// An explicit `Connection: close` wins, then an explicit `Connection: keep-alive`.
/// Otherwise the version's own default (`Version::default_keep_alive`) applies.
fn negotiate_keep_alive(version: Version, headers: &Headers) -> bool {
    if headers.contains_token("connection", "close") {
        false
    } else {
        headers.contains_token("connection", "keep-alive") || version.default_keep_alive()
    }
}
fn parse_head(head: &[u8]) -> Result<(Method, String, Option<Version>, Headers)> {
let text = std::str::from_utf8(head).map_err(|_| Error::BadRequest("non-UTF-8 header"))?;
let mut lines = text.split("\r\n");
let request_line = lines.next().ok_or(Error::BadRequest("empty request"))?;
let mut parts = request_line.split(' ');
let method = parts.next().ok_or(Error::BadRequest("no method"))?;
let target = parts.next().ok_or(Error::BadRequest("no target"))?;
let version_tok = parts.next().ok_or(Error::BadRequest("no version"))?;
if parts.next().is_some() {
return Err(Error::BadRequest("trailing request-line tokens"));
}
if method.is_empty() || target.is_empty() {
return Err(Error::BadRequest("empty request-line token"));
}
let method = Method::parse(method);
let version = Version::parse(version_tok);
let mut headers = Headers::new();
for line in lines {
if line.is_empty() {
continue;
}
if line.starts_with(' ') || line.starts_with('\t') {
return Err(Error::BadRequest("obsolete header folding"));
}
let (name, value) = line
.split_once(':')
.ok_or(Error::BadRequest("header without colon"))?;
if name.is_empty() || name.contains(' ') {
return Err(Error::BadRequest("invalid header name"));
}
headers.append(name.trim(), value.trim());
}
Ok((method, target.to_owned(), version, headers))
}
/// Decodes a `Transfer-Encoding: chunked` body (RFC 9112 §7.1) from the start of `data`.
///
/// Returns:
/// * `Ok(Some((body, consumed)))` once the last chunk and the trailer section are complete.
///   `consumed` counts every framing byte, so the caller can drop exactly that prefix.
/// * `Ok(None)` if more input is needed.
/// * `Err(StatusCode::PAYLOAD_TOO_LARGE)` as soon as the decoded body would exceed `max_body`.
/// * `Err(StatusCode::BAD_REQUEST)` for malformed framing. This includes a chunk-size line or
///   trailer section longer than `MAX_FRAMING_BYTES`; without that bound a peer could make
///   us buffer unlimited framing data that no other limit covers.
///
/// Chunk extensions and trailer fields are accepted but discarded.
fn decode_chunked(
    data: &[u8],
    max_body: usize,
) -> std::result::Result<Option<(Vec<u8>, usize)>, StatusCode> {
    const MAX_FRAMING_BYTES: usize = 8 * 1024;

    let mut pos = 0usize;
    let mut out = Vec::new();
    loop {
        let Some(eol) = find_subslice(&data[pos..], b"\r\n") else {
            if data.len() - pos > MAX_FRAMING_BYTES {
                return Err(StatusCode::BAD_REQUEST);
            }
            return Ok(None);
        };
        if eol > MAX_FRAMING_BYTES {
            return Err(StatusCode::BAD_REQUEST);
        }
        let size_line = &data[pos..pos + eol];
        // Everything after ';' is a chunk extension, which is ignored.
        let size_field = match size_line.iter().position(|&b| b == b';') {
            Some(i) => &size_line[..i],
            None => size_line,
        };
        let hex = std::str::from_utf8(size_field)
            .map_err(|_| StatusCode::BAD_REQUEST)?
            .trim_matches(|c: char| c == ' ' || c == '\t');
        // chunk-size = 1*HEXDIG; `from_str_radix` on its own would also accept a leading '+'.
        if hex.is_empty() || !hex.bytes().all(|b| b.is_ascii_hexdigit()) {
            return Err(StatusCode::BAD_REQUEST);
        }
        let size = usize::from_str_radix(hex, 16).map_err(|_| StatusCode::BAD_REQUEST)?;
        let after_size = pos + eol + 2;
        if size == 0 {
            // Trailer section: zero or more field lines followed by an empty line.
            // Consuming only up to the first CRLF would leave trailer bytes behind
            // and corrupt the next pipelined request.
            let rest = &data[after_size..];
            if rest.starts_with(b"\r\n") {
                return Ok(Some((out, after_size + 2)));
            }
            return match find_subslice(rest, b"\r\n\r\n") {
                Some(end) if end <= MAX_FRAMING_BYTES => Ok(Some((out, after_size + end + 4))),
                Some(_) => Err(StatusCode::BAD_REQUEST),
                None if rest.len() > MAX_FRAMING_BYTES => Err(StatusCode::BAD_REQUEST),
                None => Ok(None),
            };
        }
        // `out.len() <= max_body` holds on every iteration. Comparing against the remaining
        // budget avoids the overflow `out.len() + size` hits for an attacker-chosen size.
        if size > max_body.saturating_sub(out.len()) {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
        // The chunk data must be followed by a CRLF before it can be accepted.
        let available = data.len() - after_size;
        if available < 2 || available - 2 < size {
            return Ok(None);
        }
        let data_end = after_size + size;
        if &data[data_end..data_end + 2] != b"\r\n" {
            return Err(StatusCode::BAD_REQUEST);
        }
        out.extend_from_slice(&data[after_size..data_end]);
        pos = data_end + 2;
    }
}
/// Returns the index of the first occurrence of `needle` in `haystack`.
///
/// An empty `needle` never matches, and a `needle` longer than `haystack` cannot match;
/// both yield `None`.
fn find_subslice(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    match needle.len() {
        0 => None,
        width => haystack
            .windows(width)
            .position(|candidate| candidate == needle),
    }
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Formats `secs` (seconds since the Unix epoch, UTC) as an IMF-fixdate.
///
/// Example output: `Sun, 06 Nov 1994 08:49:37 GMT`.
pub(crate) fn http_date(secs: u64) -> String {
    const DAY_NAMES: [&str; 7] = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
    const MONTH_NAMES: [&str; 12] = [
        "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    ];
    const SECS_PER_DAY: u64 = 86_400;

    /// Converts a day count relative to 1970-01-01 into a proleptic Gregorian
    /// `(year, month 1..=12, day 1..=31)`, following Howard Hinnant's `civil_from_days`.
    fn civil_from_days(days: i64) -> (i64, i64, i64) {
        // Re-base on 0000-03-01 so that leap days fall at the end of each computed year.
        let shifted = days + 719_468;
        let era = shifted.div_euclid(146_097);
        let day_of_era = shifted.rem_euclid(146_097);
        let year_of_era =
            (day_of_era - day_of_era / 1460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
        let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
        let march_month = (5 * day_of_year + 2) / 153;
        let day = day_of_year - (153 * march_month + 2) / 5 + 1;
        let month = if march_month < 10 {
            march_month + 3
        } else {
            march_month - 9
        };
        // January and February belong to the following calendar year.
        let year = year_of_era + era * 400 + i64::from(month <= 2);
        (year, month, day)
    }

    let days = (secs / SECS_PER_DAY) as i64;
    let second_of_day = secs % SECS_PER_DAY;
    let (year, month, day) = civil_from_days(days);
    // 1970-01-01 was a Thursday, index 4 in DAY_NAMES.
    let weekday = (days.rem_euclid(7) + 4) % 7;

    format!(
        "{}, {:02} {} {:04} {:02}:{:02}:{:02} GMT",
        DAY_NAMES[weekday as usize],
        day,
        MONTH_NAMES[(month - 1) as usize],
        year,
        second_of_day / 3600,
        second_of_day % 3600 / 60,
        second_of_day % 60,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `input` to `conn` and returns the request it completes, if any.
    /// Panics if parsing fails.
    fn drive(conn: &mut H1Conn, input: &[u8]) -> Option<Request> {
        conn.feed(input);
        conn.poll_request().unwrap()
    }

    /// Takes everything `conn` has queued for the peer, as text.
    fn rendered(conn: &mut H1Conn) -> String {
        String::from_utf8(conn.take_out()).unwrap()
    }

    #[test]
    fn parses_simple_get() {
        let mut conn = H1Conn::default();
        let req = drive(&mut conn, b"GET /hello?x=1 HTTP/1.1\r\nHost: a\r\n\r\n")
            .expect("request should be complete");
        assert_eq!(req.method(), &Method::Get);
        assert_eq!(req.path(), "/hello");
        assert_eq!(req.query(), Some("x=1"));
        assert_eq!(req.host(), Some("a"));
        assert!(req.body().is_empty());
    }

    #[test]
    fn waits_for_full_body() {
        let mut conn = H1Conn::default();
        conn.feed(b"POST / HTTP/1.1\r\nContent-Length: 5\r\n\r\nab");
        assert!(conn.poll_request().unwrap().is_none(), "body is still incomplete");
        conn.feed(b"cde");
        let req = conn
            .poll_request()
            .unwrap()
            .expect("body should now be complete");
        assert_eq!(req.body(), b"abcde");
    }

    #[test]
    fn decodes_chunked() {
        let mut conn = H1Conn::default();
        let wire: &[u8] =
            b"POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\n\r\n";
        let req = drive(&mut conn, wire).expect("chunked body should be complete");
        assert_eq!(req.body(), b"hello");
    }

    #[test]
    fn keep_alive_default_by_version() {
        let cases: [(&[u8], bool); 2] = [
            (b"GET / HTTP/1.1\r\n\r\n", true),
            (b"GET / HTTP/1.0\r\n\r\n", false),
        ];
        for (wire, expected) in cases {
            let mut conn = H1Conn::default();
            let req = drive(&mut conn, wire).unwrap();
            assert_eq!(negotiate_keep_alive(req.version(), req.headers()), expected);
        }
    }

    #[test]
    fn serializes_response_with_framing() {
        let mut conn = H1Conn::default();
        let _request = drive(&mut conn, b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n").unwrap();
        conn.respond(Response::text("hi"));
        let out = rendered(&mut conn);
        assert!(out.starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(out.contains("Content-Length: 2\r\n"));
        assert!(out.contains("Connection: close\r\n"));
        assert!(out.ends_with("\r\n\r\nhi"));
        assert!(conn.wants_close());
    }

    #[test]
    fn head_omits_body_keeps_length() {
        let mut conn = H1Conn::default();
        let _request = drive(&mut conn, b"HEAD / HTTP/1.1\r\n\r\n").unwrap();
        conn.respond(Response::text("hello"));
        let out = rendered(&mut conn);
        assert!(out.contains("Content-Length: 5\r\n"));
        assert!(out.ends_with("\r\n\r\n"), "HEAD response must not carry a body");
    }

    #[test]
    fn rejects_te_and_cl_together() {
        let mut conn = H1Conn::default();
        conn.feed(b"POST / HTTP/1.1\r\nContent-Length: 1\r\nTransfer-Encoding: chunked\r\n\r\n");
        assert!(conn.poll_request().is_err());
        assert!(conn.wants_close());
        assert!(rendered(&mut conn).starts_with("HTTP/1.1 400"));
    }

    #[test]
    fn http_date_known_value() {
        assert_eq!(http_date(784_111_777), "Sun, 06 Nov 1994 08:49:37 GMT");
    }
}