use std::marker::PhantomData;
use super::error::RestError;
use crate::buf::WriteBuf;
/// Typestate marker: the request line is still open, so query parameters may be appended.
pub struct Query;
/// Typestate marker: the request line and per-writer headers are written; more headers may
/// follow before the body.
pub struct Headers;
/// Typestate marker: the body (or its absence) is fixed; only `finish` remains.
pub struct Ready;

/// Private module hosting `Phase` so code outside this module cannot implement it.
mod sealed {
    use super::{Headers, Query, Ready};

    /// Implemented only by the three typestate markers.
    pub trait Phase {}

    impl Phase for Query {}
    impl Phase for Headers {}
    impl Phase for Ready {}
}
/// HTTP request methods accepted by `RequestWriter::request`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Method {
    Get,
    Post,
    Put,
    Delete,
    Patch,
}

impl Method {
    /// Returns the upper-case method token as written on the request line (e.g. `"GET"`).
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Get => "GET",
            Self::Post => "POST",
            Self::Put => "PUT",
            Self::Delete => "DELETE",
            Self::Patch => "PATCH",
        }
    }
}

impl std::fmt::Display for Method {
    /// Writes the method token, honouring width/fill/alignment flags the same way `str` does.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write_str`) applies options such as `{:>8}` or `{:<7}`.
        f.pad(self.as_str())
    }
}
/// A fully serialized HTTP/1.1 request, borrowed from the `RequestWriter`'s buffer.
///
/// It holds only a shared slice, so it is `Copy`. While it is alive the writer stays
/// borrowed and cannot start another request.
#[derive(Clone, Copy)]
pub struct Request<'a> {
    data: &'a [u8],
}

impl<'a> Request<'a> {
    /// Returns the serialized request bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.data
    }

    /// Consumes the handle, returning the bytes with the full borrow lifetime `'a`.
    pub fn into_bytes(self) -> &'a [u8] {
        self.data
    }

    /// Number of bytes in the serialized request.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if the serialized request contains no bytes.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}

/// Lets a `Request` be handed directly to APIs that accept `impl AsRef<[u8]>`.
impl AsRef<[u8]> for Request<'_> {
    fn as_ref(&self) -> &[u8] {
        self.data
    }
}

impl std::fmt::Debug for Request<'_> {
    /// Shows only the length, not the raw bytes, keeping payloads out of debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Request")
            .field("len", &self.data.len())
            .finish()
    }
}
/// Lookup table of RFC 3986 "unreserved" bytes (`ALPHA / DIGIT / "-" / "." / "_" / "~"`),
/// i.e. the bytes copied into the query string without percent-encoding.
const UNRESERVED: [bool; 256] = {
    let mut lut = [false; 256];
    let mut idx: usize = 0;
    while idx < 256 {
        let byte = idx as u8;
        lut[idx] = byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~');
        idx += 1;
    }
    lut
};

/// Upper-case hex digits for `%XX` escapes (RFC 3986 recommends upper case).
const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";
fn append_percent_encoded(buf: &mut WriteBuf, input: &[u8], error: &mut Option<RestError>) {
if error.is_some() {
return;
}
let mut i = 0;
while i < input.len() {
let run_start = i;
while i < input.len() && UNRESERVED[input[i] as usize] {
i += 1;
}
if i > run_start {
checked_append(buf, &input[run_start..i], error);
if error.is_some() {
return;
}
}
if i < input.len() {
let b = input[i];
checked_append(
buf,
&[
b'%',
HEX_UPPER[(b >> 4) as usize],
HEX_UPPER[(b & 0xf) as usize],
],
error,
);
if error.is_some() {
return;
}
i += 1;
}
}
}
/// Appends `src` to `buf` unless an error is already pending or `src` does not fit.
///
/// On overflow nothing is written. `error` is set to `RestError::RequestTooLarge`, with
/// `capacity` reported as `buf.len() + buf.tailroom()`. Once `error` is `Some`, every later
/// call is a no-op, so callers can chain appends and check for failure once at the end.
fn checked_append(buf: &mut WriteBuf, src: &[u8], error: &mut Option<RestError>) {
    if error.is_none() {
        if src.len() <= buf.tailroom() {
            buf.append(src);
        } else {
            *error = Some(RestError::RequestTooLarge {
                capacity: buf.len() + buf.tailroom(),
            });
        }
    }
}
/// Returns `true` if `s` contains a carriage return or a line feed.
///
/// Callers use this to reject input that could inject extra lines into the request.
fn has_crlf(s: &str) -> bool {
    s.contains(|c: char| c == '\r' || c == '\n')
}
fn write_usize_ascii(buf: &mut WriteBuf, n: usize, error: &mut Option<RestError>) {
if n == 0 {
checked_append(buf, b"0", error);
return;
}
let mut digits = [0u8; 20]; let mut i = 20;
let mut val = n;
while val > 0 {
i -= 1;
digits[i] = (val % 10) as u8 + b'0';
val /= 10;
}
checked_append(buf, &digits[i..], error);
}
fn seal_request_line(writer: &mut RequestWriter, error: &mut Option<RestError>) {
checked_append(&mut writer.write_buf, b" HTTP/1.1\r\n", error);
checked_append(&mut writer.write_buf, &writer.host_wire, error);
if !writer.default_headers_wire.is_empty() {
checked_append(&mut writer.write_buf, &writer.default_headers_wire, error);
}
}
fn write_body(buf: &mut WriteBuf, body: &[u8], error: &mut Option<RestError>) {
checked_append(buf, b"Content-Length: ", error);
write_usize_ascii(buf, body.len(), error);
checked_append(buf, b"\r\n\r\n", error);
checked_append(buf, body, error);
}
fn write_body_header(buf: &mut WriteBuf, body_len: usize, error: &mut Option<RestError>) {
checked_append(buf, b"Content-Length: ", error);
write_usize_ascii(buf, body_len, error);
checked_append(buf, b"\r\n\r\n", error);
}
const CL_PREFIX: &[u8] = b"Content-Length: ";
const CL_PAD_LEN: usize = 20;
const CL_SUFFIX: &[u8] = b"\r\n\r\n";
fn write_content_length_placeholder(buf: &mut WriteBuf, error: &mut Option<RestError>) -> usize {
checked_append(buf, CL_PREFIX, error);
let num_offset = buf.len();
checked_append(buf, b"00000000000000000000", error); checked_append(buf, CL_SUFFIX, error);
num_offset
}
pub type BodyWriter<'a> = crate::buf::WriteBufWriter<'a>;
/// Replaces the zero placeholder written by `write_content_length_placeholder` with the
/// real body length.
///
/// `num_offset` is the offset returned by the placeholder writer. `body_len` is the number
/// of body bytes appended after it. The digits are written without leading zeros and
/// `CL_SUFFIX` is rewritten right after them. The body is shifted left over the unused
/// placeholder digits, and the buffer is shrunk by that same amount.
fn backfill_content_length(buf: &mut WriteBuf, num_offset: usize, body_len: usize) {
    // Render `body_len` right-aligned in a scratch buffer; the used tail is the number.
    let mut scratch = [0u8; CL_PAD_LEN];
    let mut first = CL_PAD_LEN;
    let mut rest = body_len;
    loop {
        first -= 1;
        scratch[first] = b'0' + (rest % 10) as u8;
        rest /= 10;
        if rest == 0 {
            break;
        }
    }
    let digits = &scratch[first..];
    // Placeholder bytes left over once the real digits are in place.
    let gap = CL_PAD_LEN - digits.len();

    let data = buf.data_mut();
    let digits_end = num_offset + digits.len();
    data[num_offset..digits_end].copy_from_slice(digits);
    data[digits_end..digits_end + CL_SUFFIX.len()].copy_from_slice(CL_SUFFIX);
    if gap > 0 {
        // The body originally began after the full-width placeholder and its suffix.
        let body_start = num_offset + CL_PAD_LEN + CL_SUFFIX.len();
        let end = data.len();
        if body_start < end {
            data.copy_within(body_start..end, body_start - gap);
        }
        buf.shrink_tail(gap);
    }
}
/// Reusable HTTP/1.1 request serializer bound to a single host.
///
/// The writer owns one fixed-capacity buffer. Every new builder clears it and starts a
/// fresh request, and the finished `Request` borrows from it.
pub struct RequestWriter {
    /// Output buffer reused for every request (cleared when a builder is created).
    write_buf: WriteBuf,
    /// Pre-serialized `Host: <host>\r\nConnection: keep-alive\r\n` lines.
    host_wire: Vec<u8>,
    /// Pre-serialized `name: value\r\n` lines sent with every request.
    default_headers_wire: Vec<u8>,
    /// Prefix prepended to every request path, stored without trailing `/`.
    base_path: Vec<u8>,
}

impl RequestWriter {
    /// Creates a writer for `host` with a 32 KiB request buffer.
    ///
    /// # Errors
    /// Returns `RestError::CrlfInjection` if `host` contains `\r` or `\n`.
    pub fn new(host: &str) -> Result<Self, RestError> {
        const HOST_PREFIX: &[u8] = b"Host: ";
        const HOST_SUFFIX: &[u8] = b"\r\nConnection: keep-alive\r\n";

        if has_crlf(host) {
            return Err(RestError::CrlfInjection);
        }
        let mut host_wire =
            Vec::with_capacity(HOST_PREFIX.len() + host.len() + HOST_SUFFIX.len());
        host_wire.extend_from_slice(HOST_PREFIX);
        host_wire.extend_from_slice(host.as_bytes());
        host_wire.extend_from_slice(HOST_SUFFIX);
        Ok(Self {
            write_buf: WriteBuf::new(32 * 1024, 0),
            host_wire,
            default_headers_wire: Vec::new(),
            base_path: Vec::new(),
        })
    }

    /// Replaces the request buffer with a fresh one of `capacity` bytes.
    ///
    /// Requests that do not fit fail with `RestError::RequestTooLarge` at `finish`.
    pub fn set_write_buffer_capacity(&mut self, capacity: usize) {
        self.write_buf = WriteBuf::new(capacity, 0);
    }

    /// Registers a header sent with every subsequent request.
    ///
    /// Headers accumulate; registering the same name twice sends it twice.
    ///
    /// # Errors
    /// Returns `RestError::CrlfInjection` if `name` or `value` contains `\r` or `\n`.
    pub fn default_header(&mut self, name: &str, value: &str) -> Result<(), RestError> {
        if has_crlf(name) || has_crlf(value) {
            return Err(RestError::CrlfInjection);
        }
        self.default_headers_wire.extend_from_slice(name.as_bytes());
        self.default_headers_wire.extend_from_slice(b": ");
        self.default_headers_wire.extend_from_slice(value.as_bytes());
        self.default_headers_wire.extend_from_slice(b"\r\n");
        Ok(())
    }

    /// Sets the prefix prepended to every request path. Trailing `/` characters are dropped.
    ///
    /// # Errors
    /// Returns `RestError::CrlfInjection` if `path` contains `\r` or `\n`.
    pub fn set_base_path(&mut self, path: &str) -> Result<(), RestError> {
        if has_crlf(path) {
            return Err(RestError::CrlfInjection);
        }
        self.base_path = path.trim_end_matches('/').as_bytes().to_vec();
        Ok(())
    }

    /// Starts a `GET` request for `path`.
    pub fn get(&mut self, path: &str) -> RequestBuilder<'_> {
        self.request(Method::Get, path)
    }

    /// Starts a `POST` request for `path`.
    pub fn post(&mut self, path: &str) -> RequestBuilder<'_> {
        self.request(Method::Post, path)
    }

    /// Starts a `PUT` request for `path`.
    pub fn put(&mut self, path: &str) -> RequestBuilder<'_> {
        self.request(Method::Put, path)
    }

    /// Starts a `DELETE` request for `path`.
    pub fn delete(&mut self, path: &str) -> RequestBuilder<'_> {
        self.request(Method::Delete, path)
    }

    /// Starts a `PATCH` request for `path`.
    pub fn patch(&mut self, path: &str) -> RequestBuilder<'_> {
        self.request(Method::Patch, path)
    }

    /// Starts a request with an arbitrary `method`. The builder begins in the query phase.
    pub fn request(&mut self, method: Method, path: &str) -> RequestBuilder<'_> {
        RequestBuilder::new(self, method, path)
    }

    /// Like [`RequestWriter::get`], but seals the request line immediately (no query phase).
    pub fn get_raw(&mut self, path: &str) -> RequestBuilder<'_, Headers> {
        self.request_raw(Method::Get, path)
    }

    /// Like [`RequestWriter::post`], but seals the request line immediately (no query phase).
    pub fn post_raw(&mut self, path: &str) -> RequestBuilder<'_, Headers> {
        self.request_raw(Method::Post, path)
    }

    /// Like [`RequestWriter::put`], but seals the request line immediately (no query phase).
    pub fn put_raw(&mut self, path: &str) -> RequestBuilder<'_, Headers> {
        self.request_raw(Method::Put, path)
    }

    /// Like [`RequestWriter::delete`], but seals the request line immediately (no query phase).
    pub fn delete_raw(&mut self, path: &str) -> RequestBuilder<'_, Headers> {
        self.request_raw(Method::Delete, path)
    }

    /// Like [`RequestWriter::patch`], but seals the request line immediately (no query phase).
    pub fn patch_raw(&mut self, path: &str) -> RequestBuilder<'_, Headers> {
        self.request_raw(Method::Patch, path)
    }

    /// Starts a request whose line is sealed right away; only headers and a body may follow.
    pub fn request_raw(&mut self, method: Method, path: &str) -> RequestBuilder<'_, Headers> {
        RequestBuilder::new_sealed(self, method, path)
    }
}
/// Typestate builder for one request, created by `RequestWriter::request` and friends.
///
/// `P` tracks which part of the request is being written (`Query`, `Headers` or `Ready`),
/// so methods can only be called in a valid order. Errors are recorded instead of returned
/// immediately, and surface from `finish`.
#[must_use = "request must be finished with .finish()"]
pub struct RequestBuilder<'a, P: sealed::Phase = Query> {
    /// Writer whose buffer receives the serialized request.
    writer: &'a mut RequestWriter,
    /// Whether the target already contains `?`, so the next query pair is joined with `&`.
    has_query: bool,
    /// First recorded error; once set, further appends are skipped.
    error: Option<RestError>,
    /// Zero-sized marker carrying the current phase.
    _phase: PhantomData<P>,
}
impl<'a> RequestBuilder<'a, Query> {
pub(crate) fn new(writer: &'a mut RequestWriter, method: Method, path: &str) -> Self {
writer.write_buf.clear();
let mut error = if has_crlf(path) {
Some(RestError::CrlfInjection)
} else {
None
};
checked_append(
&mut writer.write_buf,
method.as_str().as_bytes(),
&mut error,
);
checked_append(&mut writer.write_buf, b" ", &mut error);
if !writer.base_path.is_empty() {
checked_append(&mut writer.write_buf, &writer.base_path, &mut error);
}
checked_append(&mut writer.write_buf, path.as_bytes(), &mut error);
Self {
writer,
has_query: path.contains('?'),
error,
_phase: PhantomData,
}
}
pub(crate) fn new_sealed(
writer: &'a mut RequestWriter,
method: Method,
path: &str,
) -> RequestBuilder<'a, Headers> {
writer.write_buf.clear();
let mut error = if has_crlf(path) {
Some(RestError::CrlfInjection)
} else {
None
};
checked_append(
&mut writer.write_buf,
method.as_str().as_bytes(),
&mut error,
);
checked_append(&mut writer.write_buf, b" ", &mut error);
if !writer.base_path.is_empty() {
checked_append(&mut writer.write_buf, &writer.base_path, &mut error);
}
checked_append(&mut writer.write_buf, path.as_bytes(), &mut error);
seal_request_line(writer, &mut error);
RequestBuilder {
writer,
has_query: false,
error,
_phase: PhantomData,
}
}
pub fn query(mut self, key: &str, value: &str) -> Self {
let sep = if self.has_query { b"&" as &[u8] } else { b"?" };
checked_append(&mut self.writer.write_buf, sep, &mut self.error);
append_percent_encoded(&mut self.writer.write_buf, key.as_bytes(), &mut self.error);
checked_append(&mut self.writer.write_buf, b"=", &mut self.error);
append_percent_encoded(
&mut self.writer.write_buf,
value.as_bytes(),
&mut self.error,
);
self.has_query = true;
self
}
pub fn query_raw(mut self, key: &str, value: &str) -> Self {
if has_crlf(key) || has_crlf(value) {
self.error = Some(RestError::CrlfInjection);
return self;
}
let sep = if self.has_query { b"&" as &[u8] } else { b"?" };
checked_append(&mut self.writer.write_buf, sep, &mut self.error);
checked_append(&mut self.writer.write_buf, key.as_bytes(), &mut self.error);
checked_append(&mut self.writer.write_buf, b"=", &mut self.error);
checked_append(
&mut self.writer.write_buf,
value.as_bytes(),
&mut self.error,
);
self.has_query = true;
self
}
pub fn header(mut self, name: &str, value: &str) -> RequestBuilder<'a, Headers> {
seal_request_line(self.writer, &mut self.error);
let mut next = RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
};
next.append_header(name, value);
next
}
pub fn body(mut self, body: &[u8]) -> RequestBuilder<'a, Ready> {
seal_request_line(self.writer, &mut self.error);
write_body(&mut self.writer.write_buf, body, &mut self.error);
RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
}
}
pub fn body_writer<F, E>(mut self, f: F) -> RequestBuilder<'a, Ready>
where
F: FnOnce(&mut BodyWriter<'_>) -> Result<(), E>,
E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
seal_request_line(self.writer, &mut self.error);
if self.error.is_some() {
return RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
};
}
let num_offset =
write_content_length_placeholder(&mut self.writer.write_buf, &mut self.error);
if self.error.is_some() {
return RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
};
}
let body_len = {
let mut bw = BodyWriter::new(&mut self.writer.write_buf);
if let Err(e) = f(&mut bw) {
self.error = Some(if self.writer.write_buf.tailroom() == 0 {
RestError::RequestTooLarge {
capacity: self.writer.write_buf.len() + self.writer.write_buf.tailroom(),
}
} else {
RestError::Io(std::io::Error::other(e))
});
0
} else {
bw.written()
}
}; if self.error.is_none() {
backfill_content_length(&mut self.writer.write_buf, num_offset, body_len);
}
RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
}
}
pub fn body_fixed(
mut self,
len: usize,
f: impl FnOnce(&mut [u8]),
) -> RequestBuilder<'a, Ready> {
seal_request_line(self.writer, &mut self.error);
write_body_header(&mut self.writer.write_buf, len, &mut self.error);
if self.error.is_some() {
return RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
};
}
let buf = &mut self.writer.write_buf;
if len > buf.tailroom() {
self.error = Some(RestError::RequestTooLarge {
capacity: buf.len() + buf.tailroom(),
});
} else {
let start = buf.len();
buf.extend_zeroed(len);
let data = buf.data_mut();
f(&mut data[start..start + len]);
}
RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
}
}
pub fn finish(mut self) -> Result<Request<'a>, RestError> {
seal_request_line(self.writer, &mut self.error);
checked_append(&mut self.writer.write_buf, b"\r\n", &mut self.error);
if let Some(e) = self.error {
return Err(e);
}
Ok(Request {
data: self.writer.write_buf.data(),
})
}
}
impl<'a> RequestBuilder<'a, Headers> {
pub fn header(mut self, name: &str, value: &str) -> Self {
self.append_header(name, value);
self
}
pub fn body(mut self, body: &[u8]) -> RequestBuilder<'a, Ready> {
write_body(&mut self.writer.write_buf, body, &mut self.error);
RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
}
}
pub fn body_writer<F, E>(mut self, f: F) -> RequestBuilder<'a, Ready>
where
F: FnOnce(&mut BodyWriter<'_>) -> Result<(), E>,
E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
if self.error.is_some() {
return RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
};
}
let num_offset =
write_content_length_placeholder(&mut self.writer.write_buf, &mut self.error);
if self.error.is_some() {
return RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
};
}
let body_len = {
let mut bw = BodyWriter::new(&mut self.writer.write_buf);
if let Err(e) = f(&mut bw) {
self.error = Some(if self.writer.write_buf.tailroom() == 0 {
RestError::RequestTooLarge {
capacity: self.writer.write_buf.len() + self.writer.write_buf.tailroom(),
}
} else {
RestError::Io(std::io::Error::other(e))
});
0
} else {
bw.written()
}
}; if self.error.is_none() {
backfill_content_length(&mut self.writer.write_buf, num_offset, body_len);
}
RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
}
}
pub fn body_fixed(
mut self,
len: usize,
f: impl FnOnce(&mut [u8]),
) -> RequestBuilder<'a, Ready> {
write_body_header(&mut self.writer.write_buf, len, &mut self.error);
if self.error.is_some() {
return RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
};
}
let buf = &mut self.writer.write_buf;
if len > buf.tailroom() {
self.error = Some(RestError::RequestTooLarge {
capacity: buf.len() + buf.tailroom(),
});
} else {
let start = buf.len();
buf.extend_zeroed(len);
let data = buf.data_mut();
f(&mut data[start..start + len]);
}
RequestBuilder {
writer: self.writer,
has_query: self.has_query,
error: self.error,
_phase: PhantomData,
}
}
pub fn finish(mut self) -> Result<Request<'a>, RestError> {
checked_append(&mut self.writer.write_buf, b"\r\n", &mut self.error);
if let Some(e) = self.error {
return Err(e);
}
Ok(Request {
data: self.writer.write_buf.data(),
})
}
fn append_header(&mut self, name: &str, value: &str) {
if self.error.is_some() {
return;
}
if has_crlf(name) || has_crlf(value) {
self.error = Some(RestError::CrlfInjection);
return;
}
checked_append(&mut self.writer.write_buf, name.as_bytes(), &mut self.error);
checked_append(&mut self.writer.write_buf, b": ", &mut self.error);
checked_append(
&mut self.writer.write_buf,
value.as_bytes(),
&mut self.error,
);
checked_append(&mut self.writer.write_buf, b"\r\n", &mut self.error);
}
}
impl<'a> RequestBuilder<'a, Ready> {
    /// Returns the serialized request. The body already terminated the header block.
    ///
    /// # Errors
    /// Returns the first error recorded while building (CRLF injection, buffer overflow, or
    /// a failed body writer).
    pub fn finish(self) -> Result<Request<'a>, RestError> {
        match self.error {
            Some(e) => Err(e),
            None => Ok(Request {
                data: self.writer.write_buf.data(),
            }),
        }
    }
}