use bytes::Bytes;
use libc::{c_int, size_t};
use std::ffi::c_void;
use super::body::{hyper_body, hyper_buf};
use super::error::hyper_code;
use super::task::{hyper_task_return_type, AsTaskType};
use super::{UserDataPointer, HYPER_ITER_CONTINUE};
use crate::ext::{HeaderCaseMap, OriginalHeaderOrder, ReasonPhrase};
use crate::header::{HeaderName, HeaderValue};
use crate::{Body, HeaderMap, Method, Request, Response, Uri};
/// An HTTP request exposed through the C API; a thin wrapper around `Request<Body>`.
pub struct hyper_request(pub(super) Request<Body>);
/// An HTTP response exposed through the C API; a thin wrapper around `Response<Body>`.
pub struct hyper_response(pub(super) Response<Body>);
/// A set of HTTP headers that also remembers how header names were spelled
/// and, when recorded, the order in which entries were added.
pub struct hyper_headers {
    /// The header name/value multimap itself.
    pub(super) headers: HeaderMap,
    /// Original spelling of each header name, as passed to
    /// `hyper_headers_set`/`hyper_headers_add` or carried over from a
    /// response's `HeaderCaseMap` extension.
    orig_casing: HeaderCaseMap,
    /// Order in which header entries were added; `hyper_headers_foreach`
    /// replays it when it is non-empty.
    orig_order: OriginalHeaderOrder,
}
/// Raw header bytes stored as a response extension and handed out by
/// `hyper_response_headers_raw`.
pub(crate) struct RawHeaders(pub(crate) hyper_buf);
/// A user callback plus its user-data pointer, stored as a request extension
/// by `hyper_request_on_informational` and invoked via `OnInformational::call`.
pub(crate) struct OnInformational {
    // The C function to invoke.
    func: hyper_request_on_informational_callback,
    // Opaque user data forwarded unchanged as the callback's first argument.
    data: UserDataPointer,
}
/// Signature of the informational-response callback: it receives the user
/// data pointer and a pointer to a `hyper_response` that `OnInformational::call`
/// owns only for the duration of the call.
type hyper_request_on_informational_callback = extern "C" fn(*mut c_void, *mut hyper_response);
ffi_fn! {
    /// Construct a new HTTP request with an empty body.
    ///
    /// The returned pointer owns a heap-allocated `hyper_request`; it can be
    /// released with `hyper_request_free`.
    fn hyper_request_new() -> *mut hyper_request {
        let request = Request::new(Body::empty());
        let boxed = Box::new(hyper_request(request));
        Box::into_raw(boxed)
    } ?= std::ptr::null_mut()
}
ffi_fn! {
    /// Free a `hyper_request`.
    ///
    /// `req` must have come from `Box::into_raw` (as `hyper_request_new`
    /// produces); it is reclaimed with `Box::from_raw` and dropped. A null
    /// pointer is caught by `non_null!` instead of being dereferenced.
    fn hyper_request_free(req: *mut hyper_request) {
        let owned = non_null!(Box::from_raw(req) ?= ());
        drop(owned);
    }
}
ffi_fn! {
    /// Set the HTTP method of the request.
    ///
    /// `method` must point to `method_len` readable bytes. Returns
    /// `HYPERE_INVALID_ARG` if `method` is null or the bytes are not a valid
    /// HTTP method; `req` is checked with `non_null!`.
    fn hyper_request_set_method(req: *mut hyper_request, method: *const u8, method_len: size_t) -> hyper_code {
        let req = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG);
        // `slice::from_raw_parts` requires a non-null pointer even when the
        // length is zero, so reject null before building the slice.
        if method.is_null() {
            return hyper_code::HYPERE_INVALID_ARG;
        }
        let bytes = unsafe {
            // SAFETY: `method` is non-null; the caller must guarantee it
            // points to `method_len` readable bytes for this call.
            std::slice::from_raw_parts(method, method_len as usize)
        };
        match Method::from_bytes(bytes) {
            Ok(m) => {
                *req.0.method_mut() = m;
                hyper_code::HYPERE_OK
            }
            Err(_) => hyper_code::HYPERE_INVALID_ARG,
        }
    }
}
ffi_fn! {
    /// Set the URI of the request.
    ///
    /// `uri` must point to `uri_len` readable bytes. The bytes are copied, so
    /// the caller keeps ownership of its buffer. Returns `HYPERE_INVALID_ARG`
    /// if `uri` is null or the bytes do not parse as a URI; `req` is checked
    /// with `non_null!`.
    fn hyper_request_set_uri(req: *mut hyper_request, uri: *const u8, uri_len: size_t) -> hyper_code {
        let req = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG);
        // A null pointer is never valid for `slice::from_raw_parts`, even with
        // a zero length.
        if uri.is_null() {
            return hyper_code::HYPERE_INVALID_ARG;
        }
        let bytes = unsafe {
            // SAFETY: `uri` is non-null; the caller must guarantee it points
            // to `uri_len` readable bytes for this call.
            std::slice::from_raw_parts(uri, uri_len as usize)
        };
        // Copy into an owned `Bytes` before parsing. Handing the borrowed slice
        // to `from_maybe_shared` directly only compiles because its lifetime is
        // unbounded (and so satisfies `'static`). The explicit copy guarantees
        // the parsed `Uri` never borrows the caller's buffer.
        match Uri::from_maybe_shared(Bytes::copy_from_slice(bytes)) {
            Ok(u) => {
                *req.0.uri_mut() = u;
                hyper_code::HYPERE_OK
            }
            Err(_) => hyper_code::HYPERE_INVALID_ARG,
        }
    }
}
ffi_fn! {
    /// Set the URI of the request from separate scheme, authority, and
    /// path-and-query parts.
    ///
    /// Each part is a (pointer, length) pair, and a null pointer leaves that
    /// part unset. The parts are assembled with `Uri::builder()`. Returns
    /// `HYPERE_INVALID_ARG` if they do not form a valid URI; `req` is checked
    /// with `non_null!`.
    fn hyper_request_set_uri_parts(
        req: *mut hyper_request,
        scheme: *const u8,
        scheme_len: size_t,
        authority: *const u8,
        authority_len: size_t,
        path_and_query: *const u8,
        path_and_query_len: size_t
    ) -> hyper_code {
        // Validate `req` up front, as the other setters do, so a null request
        // pointer is never dereferenced.
        let req = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG);
        let mut builder = Uri::builder();
        if !scheme.is_null() {
            // SAFETY: non-null; the caller guarantees `scheme_len` readable bytes.
            let scheme_bytes = unsafe {
                std::slice::from_raw_parts(scheme, scheme_len as usize)
            };
            builder = builder.scheme(scheme_bytes);
        }
        if !authority.is_null() {
            // SAFETY: non-null; the caller guarantees `authority_len` readable bytes.
            let authority_bytes = unsafe {
                std::slice::from_raw_parts(authority, authority_len as usize)
            };
            builder = builder.authority(authority_bytes);
        }
        if !path_and_query.is_null() {
            // SAFETY: non-null; the caller guarantees `path_and_query_len` readable bytes.
            let path_and_query_bytes = unsafe {
                std::slice::from_raw_parts(path_and_query, path_and_query_len as usize)
            };
            builder = builder.path_and_query(path_and_query_bytes);
        }
        match builder.build() {
            Ok(u) => {
                *req.0.uri_mut() = u;
                hyper_code::HYPERE_OK
            }
            Err(_) => hyper_code::HYPERE_INVALID_ARG,
        }
    }
}
ffi_fn! {
    /// Set the HTTP version of the request.
    ///
    /// `HYPER_HTTP_VERSION_NONE` selects HTTP/1.1. Any unrecognized value
    /// returns `HYPERE_INVALID_ARG` and leaves the request's version untouched.
    fn hyper_request_set_version(req: *mut hyper_request, version: c_int) -> hyper_code {
        use http::Version;
        let request = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG);
        let chosen = match version {
            super::HYPER_HTTP_VERSION_NONE => Version::HTTP_11,
            super::HYPER_HTTP_VERSION_1_0 => Version::HTTP_10,
            super::HYPER_HTTP_VERSION_1_1 => Version::HTTP_11,
            super::HYPER_HTTP_VERSION_2 => Version::HTTP_2,
            _ => return hyper_code::HYPERE_INVALID_ARG,
        };
        *request.0.version_mut() = chosen;
        hyper_code::HYPERE_OK
    }
}
ffi_fn! {
    /// Get a mutable pointer to the HTTP headers of this request.
    ///
    /// The headers live in the request's extensions and are created on first
    /// access. The pointer refers into `req` and is only valid while `req` is.
    /// Returns null if `req` is null.
    fn hyper_request_headers(req: *mut hyper_request) -> *mut hyper_headers {
        // Previously `req` was dereferenced without a null check; validate it
        // like every other request accessor.
        let req = non_null!(&mut *req ?= std::ptr::null_mut());
        hyper_headers::get_or_default(req.0.extensions_mut()) as *mut hyper_headers
    } ?= std::ptr::null_mut()
}
ffi_fn! {
    /// Set the body of the request, taking ownership of `body`.
    ///
    /// The body is reclaimed before `req` is validated. If `req` turns out to
    /// be null, the body is still consumed (dropped) rather than handed back.
    fn hyper_request_set_body(req: *mut hyper_request, body: *mut hyper_body) -> hyper_code {
        let owned_body = non_null!(Box::from_raw(body) ?= hyper_code::HYPERE_INVALID_ARG);
        let request = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG);
        let slot = request.0.body_mut();
        *slot = owned_body.0;
        hyper_code::HYPERE_OK
    }
}
ffi_fn! {
    /// Register a callback (with user data) as an `OnInformational` extension
    /// on the request, replacing any previously registered one.
    ///
    /// When invoked, the callback receives `data` and a pointer to a
    /// `hyper_response` that is only valid for the duration of that call.
    fn hyper_request_on_informational(req: *mut hyper_request, callback: hyper_request_on_informational_callback, data: *mut c_void) -> hyper_code {
        let hook = OnInformational {
            func: callback,
            data: UserDataPointer(data),
        };
        let request = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG);
        let extensions = request.0.extensions_mut();
        extensions.insert(hook);
        hyper_code::HYPERE_OK
    }
}
impl hyper_request {
pub(super) fn finalize_request(&mut self) {
if let Some(headers) = self.0.extensions_mut().remove::<hyper_headers>() {
*self.0.headers_mut() = headers.headers;
self.0.extensions_mut().insert(headers.orig_casing);
self.0.extensions_mut().insert(headers.orig_order);
}
}
}
ffi_fn! {
    /// Free a `hyper_response`.
    ///
    /// `resp` must have been produced by `Box::into_raw`; it is reclaimed with
    /// `Box::from_raw` and dropped. A null pointer is caught by `non_null!`.
    fn hyper_response_free(resp: *mut hyper_response) {
        let owned = non_null!(Box::from_raw(resp) ?= ());
        drop(owned);
    }
}
ffi_fn! {
    /// Get the numeric HTTP status code of this response, or 0 if `resp` is null.
    fn hyper_response_status(resp: *const hyper_response) -> u16 {
        let response = non_null!(&*resp ?= 0);
        let status = response.0.status();
        status.as_u16()
    }
}
ffi_fn! {
    /// Get a pointer to the reason-phrase bytes of this response.
    ///
    /// The bytes are not guaranteed to be NUL-terminated; pair this with
    /// `hyper_response_reason_phrase_len`. The phrase may point into `resp`,
    /// so treat it as valid only while `resp` is. Returns null if `resp` is null.
    fn hyper_response_reason_phrase(resp: *const hyper_response) -> *const u8 {
        let response = non_null!(&*resp ?= std::ptr::null());
        let phrase = response.reason_phrase();
        phrase.as_ptr()
    } ?= std::ptr::null()
}
ffi_fn! {
    /// Get the length in bytes of this response's reason phrase, or 0 if
    /// `resp` is null.
    fn hyper_response_reason_phrase_len(resp: *const hyper_response) -> size_t {
        let response = non_null!(&*resp ?= 0);
        let phrase = response.reason_phrase();
        phrase.len()
    }
}
ffi_fn! {
    /// Get a pointer to the raw header buffer of this response.
    ///
    /// Returns the `hyper_buf` held in the response's `RawHeaders` extension,
    /// or null when that extension is absent or `resp` is null. The buffer
    /// lives inside `resp` and is only valid while `resp` is.
    fn hyper_response_headers_raw(resp: *const hyper_response) -> *const hyper_buf {
        let response = non_null!(&*resp ?= std::ptr::null());
        response
            .0
            .extensions()
            .get::<RawHeaders>()
            .map_or(std::ptr::null(), |raw| &raw.0 as *const hyper_buf)
    } ?= std::ptr::null()
}
ffi_fn! {
    /// Get the HTTP version of this response as a `HYPER_HTTP_VERSION_*` code.
    ///
    /// Versions without a dedicated code map to `HYPER_HTTP_VERSION_NONE`.
    /// Returns 0 if `resp` is null.
    fn hyper_response_version(resp: *const hyper_response) -> c_int {
        use http::Version;
        let response = non_null!(&*resp ?= 0);
        let version = response.0.version();
        if version == Version::HTTP_10 {
            super::HYPER_HTTP_VERSION_1_0
        } else if version == Version::HTTP_11 {
            super::HYPER_HTTP_VERSION_1_1
        } else if version == Version::HTTP_2 {
            super::HYPER_HTTP_VERSION_2
        } else {
            super::HYPER_HTTP_VERSION_NONE
        }
    }
}
ffi_fn! {
    /// Get a mutable pointer to the HTTP headers of this response.
    ///
    /// The headers live in the response's extensions and are created on first
    /// access. The pointer refers into `resp` and is only valid while `resp`
    /// is. Returns null if `resp` is null.
    fn hyper_response_headers(resp: *mut hyper_response) -> *mut hyper_headers {
        // Previously `resp` was dereferenced without a null check; validate it
        // like the other response accessors.
        let resp = non_null!(&mut *resp ?= std::ptr::null_mut());
        hyper_headers::get_or_default(resp.0.extensions_mut()) as *mut hyper_headers
    } ?= std::ptr::null_mut()
}
ffi_fn! {
    /// Take ownership of the body of this response.
    ///
    /// The response's body is replaced with `Body::default()`. The returned
    /// pointer owns a heap-allocated `hyper_body`. Returns null if `resp` is null.
    fn hyper_response_body(resp: *mut hyper_response) -> *mut hyper_body {
        let response = non_null!(&mut *resp ?= std::ptr::null_mut());
        let taken = std::mem::take(response.0.body_mut());
        let boxed = Box::new(hyper_body(taken));
        Box::into_raw(boxed)
    } ?= std::ptr::null_mut()
}
impl hyper_response {
    /// Wrap a `Response<Body>` for the C API.
    ///
    /// The response's headers, together with its `HeaderCaseMap` and
    /// `OriginalHeaderOrder` extensions (defaulted when absent), are moved into
    /// a single `hyper_headers` extension, which `hyper_response_headers`
    /// hands out.
    pub(super) fn wrap(mut resp: Response<Body>) -> hyper_response {
        let headers = std::mem::take(resp.headers_mut());
        let extensions = resp.extensions_mut();
        let orig_casing = match extensions.remove::<HeaderCaseMap>() {
            Some(casing) => casing,
            None => HeaderCaseMap::default(),
        };
        let orig_order = match extensions.remove::<OriginalHeaderOrder>() {
            Some(order) => order,
            None => OriginalHeaderOrder::default(),
        };
        extensions.insert(hyper_headers {
            headers,
            orig_casing,
            orig_order,
        });
        hyper_response(resp)
    }

    /// The reason phrase: the `ReasonPhrase` extension if present, otherwise
    /// the status code's canonical reason, otherwise an empty slice.
    fn reason_phrase(&self) -> &[u8] {
        let custom = self
            .0
            .extensions()
            .get::<ReasonPhrase>()
            .map(|phrase| phrase.as_bytes());
        custom
            .or_else(|| self.0.status().canonical_reason().map(str::as_bytes))
            .unwrap_or(&[])
    }
}
/// Tags `hyper_response` with `HYPER_TASK_RESPONSE` for the `AsTaskType`
/// trait defined in `super::task`.
// NOTE(review): the trait is `unsafe`; confirm in `super::task` which
// invariant an implementor must uphold.
unsafe impl AsTaskType for hyper_response {
    fn as_task_type(&self) -> hyper_task_return_type {
        hyper_task_return_type::HYPER_TASK_RESPONSE
    }
}
/// Callback for `hyper_headers_foreach`. It receives the user data pointer,
/// the header name (pointer, length) and the value (pointer, length). It
/// returns `HYPER_ITER_CONTINUE` to keep iterating; any other value stops
/// the iteration.
type hyper_headers_foreach_callback =
    extern "C" fn(*mut c_void, *const u8, size_t, *const u8, size_t) -> c_int;
impl hyper_headers {
    /// Return the `hyper_headers` stored in `ext`, inserting an empty one
    /// first if none is present.
    pub(super) fn get_or_default(ext: &mut http::Extensions) -> &mut hyper_headers {
        let missing = ext.get::<hyper_headers>().is_none();
        if missing {
            ext.insert(hyper_headers::default());
        }
        ext.get_mut::<hyper_headers>().unwrap()
    }
}
ffi_fn! {
    /// Iterate the headers, passing each name/value pair to `func`.
    ///
    /// `userdata` is forwarded unchanged to every call.
    ///
    /// Names are reported with their recorded original spelling when one
    /// exists; otherwise `HeaderName::as_str()` is used. If an insertion order
    /// was recorded, it is replayed. Otherwise headers are visited grouped by
    /// name.
    ///
    /// `func` returns `HYPER_ITER_CONTINUE` to keep going; any other value
    /// stops the iteration. The name and value pointers point into `headers`
    /// and must not be retained after the callback returns.
    fn hyper_headers_foreach(headers: *const hyper_headers, func: hyper_headers_foreach_callback, userdata: *mut c_void) {
        let headers = non_null!(&*headers ?= ());
        // Each value may have an originally-spelled name in `orig_casing` at
        // the same per-name index; pair them up, falling back to the
        // `HeaderName` itself.
        let mut ordered_iter = headers.orig_order.get_in_order().peekable();
        if ordered_iter.peek().is_some() {
            // Replay the recorded order: each entry is (name, index of the
            // value for that name).
            for (name, idx) in ordered_iter {
                // `hyper_headers_set` replaces every value for a name. An order
                // entry recorded by an earlier `hyper_headers_add` can therefore
                // point past the values that remain.
                // NOTE(review): assumes `OriginalHeaderOrder::insert` does not
                // prune such entries.
                // Skip stale entries instead of ending the whole iteration,
                // which would silently hide every header recorded after them.
                let (val_ptr, val_len) = match headers.headers.get_all(name).iter().nth(*idx) {
                    Some(value) => (value.as_bytes().as_ptr(), value.as_bytes().len()),
                    None => continue,
                };
                let (name_ptr, name_len) = if let Some(orig_name) = headers.orig_casing.get_all(name).nth(*idx) {
                    (orig_name.as_ref().as_ptr(), orig_name.as_ref().len())
                } else {
                    (
                        name.as_str().as_bytes().as_ptr(),
                        name.as_str().as_bytes().len(),
                    )
                };
                if HYPER_ITER_CONTINUE != func(userdata, name_ptr, name_len, val_ptr, val_len) {
                    return;
                }
            }
        } else {
            // No recorded order: walk each name's values, pairing each with the
            // next recorded spelling for that name.
            for name in headers.headers.keys() {
                let mut names = headers.orig_casing.get_all(name);
                for value in headers.headers.get_all(name) {
                    let (name_ptr, name_len) = if let Some(orig_name) = names.next() {
                        (orig_name.as_ref().as_ptr(), orig_name.as_ref().len())
                    } else {
                        (
                            name.as_str().as_bytes().as_ptr(),
                            name.as_str().as_bytes().len(),
                        )
                    };
                    let val_ptr = value.as_bytes().as_ptr();
                    let val_len = value.as_bytes().len();
                    if HYPER_ITER_CONTINUE != func(userdata, name_ptr, name_len, val_ptr, val_len) {
                        return;
                    }
                }
            }
        }
    }
}
ffi_fn! {
    /// Set the header `name` to `value`, replacing any existing values.
    ///
    /// The spelling of `name` as passed in is remembered for
    /// `hyper_headers_foreach`. Returns `HYPERE_INVALID_ARG` if `headers` is
    /// null or the name/value are not valid.
    fn hyper_headers_set(headers: *mut hyper_headers, name: *const u8, name_len: size_t, value: *const u8, value_len: size_t) -> hyper_code {
        let headers = non_null!(&mut *headers ?= hyper_code::HYPERE_INVALID_ARG);
        let (name, value, orig_name) = match unsafe { raw_name_value(name, name_len, value, value_len) } {
            Ok(parts) => parts,
            Err(code) => return code,
        };
        headers.headers.insert(&name, value);
        // `orig_name` is not used again, so it can be moved rather than cloned.
        headers.orig_casing.insert(name.clone(), orig_name);
        headers.orig_order.insert(name);
        hyper_code::HYPERE_OK
    }
}
ffi_fn! {
    /// Append `value` to the values of header `name`, keeping existing ones.
    ///
    /// The spelling of `name` as passed in is remembered, and the new entry's
    /// position is recorded for `hyper_headers_foreach`. Returns
    /// `HYPERE_INVALID_ARG` if `headers` is null or the name/value are not
    /// valid.
    fn hyper_headers_add(headers: *mut hyper_headers, name: *const u8, name_len: size_t, value: *const u8, value_len: size_t) -> hyper_code {
        let headers = non_null!(&mut *headers ?= hyper_code::HYPERE_INVALID_ARG);
        let (name, value, orig_name) = match unsafe { raw_name_value(name, name_len, value, value_len) } {
            Ok(parts) => parts,
            Err(code) => return code,
        };
        headers.headers.append(&name, value);
        // `orig_name` is not used again, so it can be moved rather than cloned.
        headers.orig_casing.append(&name, orig_name);
        headers.orig_order.append(name);
        hyper_code::HYPERE_OK
    }
}
impl Default for hyper_headers {
fn default() -> Self {
Self {
headers: Default::default(),
orig_casing: HeaderCaseMap::default(),
orig_order: OriginalHeaderOrder::default(),
}
}
}
/// Build a validated header name and value from raw C buffers.
///
/// Returns the parsed `HeaderName`, the parsed `HeaderValue`, and a copy of
/// the name bytes exactly as supplied. Callers store that copy to remember the
/// caller's spelling.
///
/// Fails with `HYPERE_INVALID_ARG` if the name or value is invalid, or if a
/// pointer is null while its length is non-zero. A null pointer with length 0
/// is treated as an empty buffer.
///
/// # Safety
///
/// Every non-null pointer must reference at least its paired length of
/// readable bytes for the duration of the call.
unsafe fn raw_name_value(
    name: *const u8,
    name_len: size_t,
    value: *const u8,
    value_len: size_t,
) -> Result<(HeaderName, HeaderValue, Bytes), hyper_code> {
    /// Convert a (pointer, length) pair into a slice without ever handing a
    /// null pointer to `slice::from_raw_parts`, which is undefined behavior
    /// even for a zero length.
    unsafe fn raw_slice<'a>(ptr: *const u8, len: size_t) -> Result<&'a [u8], hyper_code> {
        if ptr.is_null() {
            if len == 0 {
                let empty: &'a [u8] = &[];
                Ok(empty)
            } else {
                Err(hyper_code::HYPERE_INVALID_ARG)
            }
        } else {
            // SAFETY: `ptr` is non-null and the caller guarantees `len`
            // readable bytes behind it.
            Ok(std::slice::from_raw_parts(ptr, len))
        }
    }

    let name = raw_slice(name, name_len)?;
    let orig_name = Bytes::copy_from_slice(name);
    let name = match HeaderName::from_bytes(name) {
        Ok(name) => name,
        Err(_) => return Err(hyper_code::HYPERE_INVALID_ARG),
    };
    let value = raw_slice(value, value_len)?;
    let value = match HeaderValue::from_bytes(value) {
        Ok(val) => val,
        Err(_) => return Err(hyper_code::HYPERE_INVALID_ARG),
    };
    Ok((name, value, orig_name))
}
impl OnInformational {
pub(crate) fn call(&mut self, resp: Response<Body>) {
let mut resp = hyper_response::wrap(resp);
(self.func)(self.data.0, &mut resp);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `hyper_headers_foreach` callback that appends `name: value\r\n` to the
    /// `Vec<u8>` passed through the user-data pointer.
    extern "C" fn concat(
        vec: *mut c_void,
        name: *const u8,
        name_len: usize,
        value: *const u8,
        value_len: usize,
    ) -> c_int {
        // SAFETY: every caller in this module passes a `&mut Vec<u8>` as the
        // user data, and the name/value pairs point into live header storage.
        unsafe {
            let vec = &mut *(vec as *mut Vec<u8>);
            let name = std::slice::from_raw_parts(name, name_len);
            let value = std::slice::from_raw_parts(value, value_len);
            vec.extend(name);
            vec.extend(b": ");
            vec.extend(value);
            vec.extend(b"\r\n");
        }
        HYPER_ITER_CONTINUE
    }

    /// Append `name: value` to `headers` through the FFI entry point.
    fn add(headers: &mut hyper_headers, name: &[u8], value: &[u8]) {
        hyper_headers_add(headers, name.as_ptr(), name.len(), value.as_ptr(), value.len());
    }

    /// Collect everything `hyper_headers_foreach` reports, rendered by `concat`.
    fn render(headers: &hyper_headers) -> Vec<u8> {
        let mut vec = Vec::<u8>::new();
        hyper_headers_foreach(headers, concat, &mut vec as *mut _ as *mut c_void);
        vec
    }

    #[test]
    fn test_headers_foreach_cases_preserved() {
        let mut headers = hyper_headers::default();
        add(&mut headers, b"Set-CookiE", b"a=b");
        add(&mut headers, b"SET-COOKIE", b"c=d");
        assert_eq!(render(&headers), b"Set-CookiE: a=b\r\nSET-COOKIE: c=d\r\n");
    }

    #[cfg(all(feature = "http1", feature = "ffi"))]
    #[test]
    fn test_headers_foreach_order_preserved() {
        let mut headers = hyper_headers::default();
        add(&mut headers, b"Set-CookiE", b"a=b");
        add(&mut headers, b"Content-Encoding", b"gzip");
        add(&mut headers, b"SET-COOKIE", b"c=d");
        assert_eq!(
            render(&headers),
            b"Set-CookiE: a=b\r\nContent-Encoding: gzip\r\nSET-COOKIE: c=d\r\n"
        );
    }
}