mod method;
pub use method::Method;
mod path;
pub use path::Path;
mod query;
pub use query::QueryParams;
mod headers;
#[allow(unused)]
pub use headers::Header as RequestHeader;
pub use headers::Headers as RequestHeaders;
mod context;
pub use context::Context;
mod from_request;
pub use from_request::FromRequest;
#[cfg(test)]
mod _test_extract;
#[cfg(test)]
mod _test_headers;
#[cfg(test)]
mod _test_parse;
pub use ohkami_lib::CowSlice;
#[cfg(feature = "__rt__")]
use ohkami_lib::Slice;
#[cfg(feature = "__rt_native__")]
use crate::__rt__::AsyncRead;
#[allow(unused)]
use {byte_reader::Reader, std::borrow::Cow, std::pin::Pin};
/// An incoming HTTP request.
///
/// Depending on the enabled runtime feature, the request owns some backing
/// storage (`__buf__`, `__url__` or `__query__`) that the parsed `path`,
/// `query` and header values may point into. Field declaration order is kept
/// as-is because it also fixes drop order.
pub struct Request {
    /// Raw read buffer (native runtimes). `Request::read` fills it from the
    /// stream and parses it in place, so `path`, `query`, header values and a
    /// fully-buffered payload may reference bytes inside it.
    #[cfg(feature = "__rt_native__")]
    pub(super) __buf__: Box<[u8]>,
    /// Request URL (`rt_worker`). Uninitialized until `take_over` (or the
    /// test-only `read`) writes it; `path` and `query` are then built from it.
    #[cfg(feature = "rt_worker")]
    pub(super) __url__: std::mem::MaybeUninit<::worker::Url>,
    /// Owned raw query string (`rt_lambda`). Uninitialized until `take_over`
    /// writes it (the test-only `read` writes it only when a `?` is present);
    /// `query` is then built from it.
    #[cfg(feature = "rt_lambda")]
    pub(super) __query__: std::mem::MaybeUninit<Box<str>>,
    /// HTTP method.
    pub method: Method,
    /// Request path.
    pub path: Path,
    /// Query parameters.
    pub query: QueryParams,
    /// Request headers.
    pub headers: RequestHeaders,
    /// Request body, if any.
    pub payload: Option<CowSlice>,
    /// Per-request context storage.
    pub context: Context,
    /// Peer IP address. Native runtimes receive it at construction; worker and
    /// lambda runtimes start with `crate::util::IP_0000` and may overwrite it.
    pub ip: std::net::IpAddr,
}
impl Request {
    /// Reads and validates the `Content-Length` header.
    ///
    /// Returns `Ok(None)` when the header is absent or equals `0`, and
    /// `Ok(Some(size))` for a positive length.
    ///
    /// # Errors
    ///
    /// - `Response::BadRequest` if the header value does not parse as `usize`,
    ///   or if a non-zero length is declared for `GET`, `HEAD` or `OPTIONS`.
    /// - `Response::PayloadTooLarge` (native runtimes only) if the length is
    ///   greater than `config.request_payload_limit`.
    #[cfg(feature = "__rt__")]
    #[inline]
    fn get_payload_size(
        &self,
        #[cfg(feature = "__rt_native__")] config: &crate::Config,
    ) -> Result<Option<std::num::NonZeroUsize>, crate::Response> {
        use crate::Response;

        let Some(size) = self
            .headers
            .content_length()
            .map(|s| s.parse().map_err(|_| Response::BadRequest()))
            .transpose()?
            .and_then(std::num::NonZeroUsize::new)
        else {
            return Ok(None);
        };

        // These methods are treated as body-less; declaring a body is rejected.
        if matches!(self.method, Method::GET | Method::HEAD | Method::OPTIONS) {
            return Err(Response::BadRequest());
        }

        #[cfg(feature = "__rt_native__")]
        if size.get() > config.request_payload_limit {
            return Err(Response::PayloadTooLarge());
        }

        Ok(Some(size))
    }

    /// Creates an empty placeholder request, to be filled by `read` or
    /// `take_over`.
    ///
    /// On native runtimes this records the peer `ip` and allocates a zeroed
    /// read buffer of `config.request_bufsize` bytes. On worker/lambda
    /// runtimes `ip` starts as `crate::util::IP_0000` and the URL/query
    /// storage is left uninitialized.
    #[cfg(feature = "__rt__")]
    #[inline]
    pub(crate) fn uninit(
        #[cfg(feature = "__rt_native__")] ip: std::net::IpAddr,
        #[cfg(feature = "__rt_native__")] config: &crate::Config,
    ) -> Self {
        Self {
            #[cfg(feature = "__rt_native__")]
            ip,
            #[cfg(any(feature = "rt_worker", feature = "rt_lambda"))]
            ip: crate::util::IP_0000,
            #[cfg(feature = "__rt_native__")]
            __buf__: vec![0u8; config.request_bufsize].into_boxed_slice(),
            #[cfg(feature = "rt_worker")]
            __url__: std::mem::MaybeUninit::uninit(),
            #[cfg(feature = "rt_lambda")]
            __query__: std::mem::MaybeUninit::uninit(),
            method: Method::GET,
            path: Path::uninit(),
            query: QueryParams::new(b""),
            headers: RequestHeaders::new(),
            payload: None,
            context: Context::init(),
        }
    }

    /// Resets this request, presumably so it can be reused for another
    /// request.
    ///
    /// A non-zero first buffer byte marks the buffer as used. Only then is the
    /// buffer zeroed (up to the first zero byte) and are `path`, `query`,
    /// `headers`, `payload` and `context` reset. `method` and `ip` are left
    /// untouched.
    #[cfg(feature = "__rt_native__")]
    #[inline(always)]
    pub(crate) fn clear(&mut self) {
        if self.__buf__[0] != 0 {
            for b in &mut *self.__buf__ {
                match b {
                    0 => break,
                    _ => *b = 0,
                }
            }
            self.path = Path::uninit();
            self.query = QueryParams::new(b"");
            self.headers.clear();
            self.payload = None;
            self.context.clear();
        }
    }

    /// Reads one HTTP/1.1 request from `stream` and parses it in place.
    ///
    /// The request line and all headers must arrive in a single `read` call,
    /// and therefore fit in `request_bufsize`. The body is taken from the bytes
    /// received after the headers; if that is not enough, the rest is read
    /// from `stream`.
    ///
    /// Returns `Ok(None)` on EOF, on a connection reset, or when the request
    /// does not start with a method `Method::from_bytes` recognizes. Returns
    /// `Ok(Some(()))` once a request has been parsed.
    ///
    /// # Errors
    ///
    /// - `InternalServerError` for other I/O errors on the initial read.
    /// - `BadRequest` for a malformed request line, or a body error from
    ///   `get_payload_size` / `read_payload`.
    /// - `HTTPVersionNotSupported` unless the version is exactly `HTTP/1.1`.
    /// - `RequestHeaderFieldsTooLarge` when a header line is cut off, which
    ///   typically means the buffer was too small.
    /// - `PayloadTooLarge` when `Content-Length` exceeds the configured limit.
    #[cfg(feature = "__rt_native__")]
    pub(crate) async fn read(
        mut self: Pin<&mut Self>,
        stream: &mut (impl AsyncRead + Unpin),
        config: &crate::Config,
    ) -> Result<Option<()>, crate::Response> {
        use crate::Response;

        let read_len = match stream.read(&mut self.__buf__).await {
            Ok(0) => return Ok(None),
            Ok(n) => n,
            Err(e) => {
                return match e.kind() {
                    std::io::ErrorKind::ConnectionReset => Ok(None),
                    _ => Err({
                        crate::WARNING!("Failed to read stream: {e}");
                        Response::InternalServerError()
                    }),
                };
            }
        };

        // Parse only the `read_len` bytes actually received. Anything after
        // them in the buffer is zero padding or leftovers from an earlier
        // request, and must not be mistaken for request data (e.g. a body).
        //
        // SAFETY: `Slice` detaches the lifetime from `self` so that fields of
        // `self` can be assigned while parsing. The parsed pieces point into
        // `__buf__`, which `self` owns and which is only rewritten by `read`
        // and `clear` (`clear` also drops the parsed pieces).
        let mut r = Reader::new(unsafe {
            Slice::from_bytes(&self.__buf__[..read_len]).as_bytes()
        });

        match Method::from_bytes(r.read_while(|b| b != &b' ')) {
            None => return Ok(None),
            Some(method) => self.method = method,
        }

        r.next_if(|b| *b == b' ').ok_or_else(Response::BadRequest)?;

        self.path
            .init_with_request_bytes(r.read_while(|b| !matches!(b, b' ' | b'?')))?;

        // `consume_oneof` yields the index of the matched token: 0 for " ",
        // 1 for "?". `None` means the request line ended right after the path,
        // which is malformed input rather than a server bug.
        if r.consume_oneof([" ", "?"]).ok_or_else(Response::BadRequest)? == 1 {
            self.query = QueryParams::new(r.read_while(|b| b != &b' '));
            r.next_if(|b| *b == b' ').ok_or_else(Response::BadRequest)?;
        }

        r.consume("HTTP/1.1\r\n")
            .ok_or_else(Response::HTTPVersionNotSupported)?;

        // Shared error path for a header line that ends before its `": "`
        // separator or its terminating CRLF.
        let headers_too_large = || {
            crate::WARNING!(
                "\
                [Request::read] Unexpected end of headers! \
                Maybe request buffer size is not enough. \
                Try setting `request_bufsize` of Config, \
                or `OHKAMI_REQUEST_BUFSIZE` environment variable, \
                to a larger value (default: {}).\
                ",
                crate::Config::default().request_bufsize
            );
            Response::RequestHeaderFieldsTooLarge()
        };

        while r.consume("\r\n").is_none() {
            let key_bytes = r.read_while(|b| b != &b':');
            r.consume(": ").ok_or_else(headers_too_large)?;
            let value = CowSlice::Ref(Slice::from_bytes(r.read_while(|b| b != &b'\r')));
            r.consume("\r\n").ok_or_else(headers_too_large)?;
            if let Some(key) = RequestHeader::from_bytes(key_bytes) {
                self.headers.append(key, value);
            } else {
                self.headers
                    .insert_custom(Slice::from_bytes(key_bytes), value)
            }
        }

        if let Some(payload_size) = self.get_payload_size(config)? {
            self.payload =
                Some(Request::read_payload(stream, r.remaining(), payload_size.get()).await?);
        }

        Ok(Some(()))
    }

    /// Builds a `size`-byte payload from the already-buffered bytes
    /// `remaining_buf` plus, if needed, further bytes read from `stream`.
    ///
    /// `remaining_buf` must contain only bytes actually received from the
    /// peer. When it already holds the whole payload, the result references
    /// it (`CowSlice::Ref`) without copying; otherwise an owned buffer is
    /// returned.
    ///
    /// # Errors
    ///
    /// `BadRequest` if the stream ends or fails before `size` bytes arrive.
    #[cfg(feature = "__rt_native__")]
    #[inline]
    async fn read_payload(
        stream: &mut (impl AsyncRead + Unpin),
        remaining_buf: &[u8],
        size: usize,
    ) -> Result<CowSlice, crate::Response> {
        let remaining_buf_len = remaining_buf.len();

        // No heuristic on the first byte's value here: a binary body may
        // legitimately start with 0x00.
        if remaining_buf_len == 0 {
            crate::DEBUG!("\n[read_payload] case: remaining_buf.is_empty()\n");
            let mut bytes = vec![0; size].into_boxed_slice();
            if let Err(err) = stream.read_exact(&mut bytes).await {
                crate::ERROR!("[Request::read_payload] Failed to read payload from stream: {err}");
                return Err(crate::Response::BadRequest());
            }
            Ok(CowSlice::Own(bytes))
        } else if size <= remaining_buf_len {
            crate::DEBUG!("\n[read_payload] case: starts_at + size <= BUF_SIZE\n");
            // SAFETY: `size <= remaining_buf.len()`, so the first `size`
            // bytes are in bounds; they live in the request buffer.
            #[allow(unused_unsafe/* I don't know why but rustc sometimes put warnings to this unsafe as unnecessary */)]
            Ok(CowSlice::Ref(unsafe {
                Slice::new_unchecked(remaining_buf.as_ptr(), size)
            }))
        } else {
            crate::DEBUG!("\n[read_payload] case: else\n");
            let mut bytes = vec![0; size].into_boxed_slice();
            // `size > remaining_buf_len` in this branch, so the split is in
            // bounds: the head receives the buffered prefix and the tail is
            // filled from the stream.
            let (head, tail) = bytes.split_at_mut(remaining_buf_len);
            head.copy_from_slice(remaining_buf);
            if let Err(err) = stream.read_exact(tail).await {
                crate::ERROR!("[Request::read_payload] Failed to read payload from stream: {err}");
                return Err(crate::Response::BadRequest());
            }
            Ok(CowSlice::Own(bytes))
        }
    }

    /// Test-only (`debug_assertions`, for `ohkami::testing`) parser for
    /// worker/lambda runtimes. Parses one complete raw HTTP/1.1 request from
    /// `raw_bytes`.
    ///
    /// Header values are copied, and custom header names are leaked with
    /// `Box::leak` so they do not borrow `raw_bytes`. Malformed header lines,
    /// non-UTF-8 targets or an unparsable URL panic. Everything after the
    /// headers becomes the payload when `Content-Length` is positive.
    #[cfg(any(feature = "rt_worker", feature = "rt_lambda"))]
    #[cfg(debug_assertions/* for `ohkami::testing` */)]
    pub(crate) async fn read(
        mut self: Pin<&mut Self>,
        raw_bytes: &mut &[u8],
        _: &crate::Config,
    ) -> Result<Option<()>, crate::Response> {
        use crate::Response;

        self.ip = crate::util::IP_0000;

        let mut r = Reader::new(raw_bytes);

        match Method::from_bytes(r.read_while(|b| b != &b' ')) {
            None => return Ok(None),
            Some(method) => self.method = method,
        }

        r.next_if(|b| *b == b' ').ok_or_else(Response::BadRequest)?;

        #[cfg(feature = "rt_worker")]
        {
            self.__url__.write({
                let mut url = String::from("http://test.ohkami");
                url.push_str(std::str::from_utf8(r.read_while(|b| b != &b' ')).unwrap());
                ::worker::Url::parse(&url).unwrap()
            });
            // Consume the space before the HTTP version, as the `rt_lambda`
            // branch does. Without this, `consume("HTTP/1.1\r\n")` below sees
            // " HTTP/1.1" and always fails.
            r.next_if(|b| *b == b' ').ok_or_else(Response::BadRequest)?;
            // SAFETY: `__url__` was written just above. The path bytes are
            // detached via `Slice` so `self.path` can be mutated; they stay
            // inside `__url__`, which `self` owns.
            unsafe {
                let __url__ = self.__url__.assume_init_ref();
                let path = Slice::from_bytes(__url__.path().as_bytes()).as_bytes();
                self.query = QueryParams::new(__url__.query().unwrap_or_default().as_bytes());
                self.path.init_with_request_bytes(path)?;
            }
        }

        #[cfg(feature = "rt_lambda")]
        {
            let path_bytes = r.read_while(|b| b != &b' ' && b != &b'?');
            self.path.init_with_request_bytes(path_bytes)?;
            if r.next_if(|b| *b == b'?').is_some() {
                self.__query__.write(
                    std::str::from_utf8(r.read_while(|b| b != &b' '))
                        .unwrap()
                        .to_owned()
                        .into_boxed_str(),
                );
                // SAFETY: `__query__` was written just above.
                unsafe {
                    self.query = QueryParams::new(self.__query__.assume_init_ref().as_bytes());
                }
            }
            r.next_if(|b| *b == b' ').ok_or_else(Response::BadRequest)?;
        }

        r.consume("HTTP/1.1\r\n")
            .ok_or_else(Response::HTTPVersionNotSupported)?;

        while r.consume("\r\n").is_none() {
            let key_bytes = r.read_while(|b| b != &b':');
            r.consume(": ").unwrap();
            let value = CowSlice::Own(r.read_while(|b| b != &b'\r').to_owned().into_boxed_slice());
            r.consume("\r\n").unwrap();
            if let Some(key) = RequestHeader::from_bytes(key_bytes) {
                self.headers.append(key, value);
            } else {
                self.headers.insert_custom(
                    // Leaked on purpose: test-only, so the name does not
                    // borrow `raw_bytes`.
                    Slice::from_bytes(Box::leak(key_bytes.to_owned().into_boxed_slice())),
                    value,
                )
            }
        }

        if self.get_payload_size()?.is_some() {
            self.payload = Some(CowSlice::Own(r.remaining().into()));
        }

        Ok(Some(()))
    }

    /// Fills this request from a Cloudflare Workers `::worker::Request`.
    ///
    /// Loads `(ctx, env)` into `context`, stores the URL in `__url__` and
    /// derives `path`/`query` from it, copies the headers and reads the whole
    /// body. The body is always stored as `Some`, even when it is empty.
    /// NOTE(review): the native path leaves `payload` as `None` when there is
    /// no body — confirm whether callers rely on this difference.
    /// If a parsable `cf-connecting-ip` header is present, it becomes `ip`.
    ///
    /// # Errors
    ///
    /// `NotImplemented` for methods `Method::from_worker` rejects,
    /// `BadRequest` for an invalid URL or path, and `InternalServerError` if
    /// the body cannot be read.
    #[cfg(feature = "rt_worker")]
    pub(crate) async fn take_over(
        mut self: Pin<&mut Self>,
        mut req: ::worker::Request,
        env: ::worker::Env,
        ctx: ::worker::Context,
    ) -> Result<(), crate::Response> {
        use crate::Response;

        self.context.load((ctx, env));

        self.method = Method::from_worker(req.method()).ok_or_else(|| {
            Response::NotImplemented().with_text("ohkami doesn't support `CONNECT`, `TRACE` method")
        })?;

        self.__url__.write(
            req.url()
                .map_err(|_| Response::BadRequest().with_text("Invalid request URL"))?,
        );
        crate::DEBUG!("Load __url__: {:?}", self.__url__);

        // SAFETY: `__url__` was written just above. The path bytes are
        // detached via `Slice` so `self.path` can be mutated; they stay inside
        // `__url__`, which `self` owns.
        unsafe {
            let __url__ = self.__url__.assume_init_ref();
            let path = Slice::from_bytes(__url__.path().as_bytes()).as_bytes();
            self.query = QueryParams::new(__url__.query().unwrap_or_default().as_bytes());
            self.path.init_with_request_bytes(path)?;
        }

        self.headers.take_over(req.headers());

        self.payload = Some(CowSlice::Own(
            req.bytes()
                .await
                .map_err(|_| {
                    Response::InternalServerError().with_text("Failed to read request payload")
                })?
                .into(),
        ));

        // A malformed header value must not bring the worker down: ignore it
        // and keep the default `ip` instead of panicking.
        if let Some(ip) = self
            .headers
            .get("cf-connecting-ip")
            .and_then(|ip| ip.parse::<std::net::IpAddr>().ok())
        {
            self.ip = ip;
        }

        Ok(())
    }

    /// Fills this request from an AWS Lambda HTTP event.
    ///
    /// Stores the raw query string in `__query__` and derives `query` from it.
    /// Loads `requestContext` into `context` and takes `path`, `method` and
    /// `ip` from it. Uses the event headers, folding any `cookies` into one
    /// `Cookie` header joined with `"; "`. Stores the body, base64-decoding it
    /// when `isBase64Encoded` is set.
    ///
    /// # Errors
    ///
    /// An "unsupported path format" message if the path cannot be parsed, or
    /// the error from `crate::util::base64_decode`.
    #[cfg(feature = "rt_lambda")]
    pub(crate) fn take_over(
        mut self: Pin<&mut Self>,
        ::lambda_runtime::LambdaEvent {
            payload: req,
            context: _,
        }: ::lambda_runtime::LambdaEvent<crate::x_lambda::LambdaHTTPRequest>,
    ) -> Result<(), lambda_runtime::Error> {
        self.__query__.write(req.rawQueryString.into_boxed_str());
        // SAFETY: `__query__` was written just above.
        unsafe {
            self.query = QueryParams::new(self.__query__.assume_init_ref().as_bytes());
        }

        self.context.load(req.requestContext);
        {
            // SAFETY: re-slicing detaches the lifetime from the borrow of
            // `self.context` so `self.path` can be mutated. The bytes live in
            // the request context `self` now holds (presumably until
            // `context` is cleared or reloaded — TODO confirm).
            let path_bytes = unsafe {
                let bytes = self.context.lambda().http.path.as_bytes();
                std::slice::from_raw_parts(bytes.as_ptr(), bytes.len())
            };
            self.path
                .init_with_request_bytes(path_bytes)
                .map_err(|_| crate::util::ErrorMessage("unsupported path format".into()))?;
            self.method = self.context.lambda().http.method;
            self.ip = self.context.lambda().http.sourceIp;
        }

        self.headers = req.headers;
        if !req.cookies.is_empty() {
            self.headers.set().cookie(req.cookies.join("; "));
        }

        if let Some(body) = req.body {
            self.payload = Some(CowSlice::Own(
                (if req.isBase64Encoded {
                    crate::util::base64_decode(body)?
                } else {
                    body.into_bytes()
                })
                .into_boxed_slice(),
            ));
        }

        Ok(())
    }
}
impl Request {
    /// Returns the request body as raw bytes, or `None` when the request
    /// carries no payload.
    #[inline]
    pub fn payload(&self) -> Option<&[u8]> {
        match &self.payload {
            Some(body) => Some(&**body),
            None => None,
        }
    }
}
const _: () = {
    impl std::fmt::Debug for Request {
        /// Formats the request as `Request { .. }`.
        ///
        /// `ip` is listed only when a runtime feature (`__rt__`) is enabled.
        /// The query is shown under the key `queries`. The payload, when
        /// present, is rendered as lossy UTF-8.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let mut builder = f.debug_struct("Request");

            #[cfg(feature = "__rt__")]
            {
                builder.field("ip", &self.ip);
            }

            builder.field("method", &self.method);
            builder.field("path", &self.path.str());
            builder.field("queries", &self.query);
            builder.field("headers", &self.headers);

            if let Some(cow_slice) = &self.payload {
                // `CowSlice::as_bytes` is `unsafe`; presumably it relies on a
                // borrowed payload's backing buffer still being alive — TODO
                // confirm its contract.
                let bytes = unsafe { cow_slice.as_bytes() };
                builder.field("payload", &String::from_utf8_lossy(bytes));
            }

            builder.finish()
        }
    }
};
#[cfg(feature = "__rt__")]
#[cfg(test)]
const _: () = {
    impl PartialEq for Request {
        /// Test-only equality over the parsed parts of a request: method,
        /// normalized path, query, headers and payload. `ip`, `context` and
        /// runtime-specific backing storage are not compared.
        fn eq(&self, other: &Self) -> bool {
            if self.method != other.method {
                return false;
            }

            // `normalized_bytes` is `unsafe`; presumably both paths must be
            // initialized (tests compare parsed requests) — TODO confirm.
            let paths_match =
                unsafe { self.path.normalized_bytes() == other.path.normalized_bytes() };

            paths_match
                && self.query == other.query
                && self.headers == other.headers
                && self.payload == other.payload
        }
    }
};