use crate::abi;
use crate::body::{Body, BodyHandle};
use crate::response::ResponseHandle;
use crate::Error;
use bytes::{BufMut, BytesMut};
use http::header::{HeaderName, HeaderValue};
use http::{Method, Request, Response, Uri, Version};
use lazy_static::lazy_static;
use std::convert::TryFrom;
use std::io::Write;
use std::sync::Mutex;
/// Opaque `u32` handle naming a request that is managed through the `xqd_req_*`
/// hostcalls.
///
/// `repr(transparent)` gives the struct the exact layout of its single `u32` field;
/// presumably this is what lets `&mut RequestHandle` be handed to hostcalls as an
/// out-parameter.
#[repr(transparent)]
#[derive(Hash, PartialEq, Eq, Debug)]
pub struct RequestHandle {
    /// Raw host-assigned handle value.
    handle: u32,
}
// Records whether a handle to the downstream request has already been handed out
// during this execution. `RequestHandle::downstream` and
// `downstream_request_and_body_handles` check it while holding the lock and set it
// to `true` only after the `xqd_req_body_downstream_get` hostcall succeeds.
lazy_static! {
    pub(crate) static ref GOT_DOWNSTREAM: Mutex<bool> = Mutex::new(false);
}
impl RequestHandle {
pub const INVALID: Self = RequestHandle {
handle: fastly_shared::INVALID_REQUEST_HANDLE,
};
pub fn is_invalid(&self) -> bool {
self == &Self::INVALID
}
pub(crate) unsafe fn handle(&self) -> Self {
Self {
handle: self.handle,
}
}
pub fn downstream() -> Result<Self, Error> {
let mut got = GOT_DOWNSTREAM.lock().unwrap();
if *got {
return Err(Error::new(
"cannot get more than one handle to the downstream request per execution",
));
}
let mut handle = RequestHandle::INVALID;
let status = unsafe { abi::xqd_req_body_downstream_get(&mut handle, std::ptr::null_mut()) };
if status.is_err() || handle.is_invalid() {
Err(Error::new("xqd_req_body_downstream_get failed"))
} else {
*got = true;
Ok(handle)
}
}
pub fn new() -> Result<Self, Error> {
let mut handle = RequestHandle::INVALID;
let status = unsafe { abi::xqd_req_new(&mut handle) };
if status.is_err() || handle.is_invalid() {
Err(Error::new("xqd_req_new failed"))
} else {
Ok(handle)
}
}
pub fn get_header_names<'a>(
&'a self,
buf_size: usize,
) -> impl Iterator<Item = Result<HeaderName, Error>> + 'a {
abi::MultiValueHostcall::new(
b'\0',
buf_size,
move |buf, buf_size, cursor, ending_cursor, nwritten| unsafe {
abi::xqd_req_header_names_get(
self.handle(),
buf,
buf_size,
cursor,
ending_cursor,
nwritten,
)
},
)
.map(|res| {
res.and_then(|name_bytes| {
HeaderName::from_bytes(&name_bytes)
.map_err(|e| Error::new(format!("invalid header name: {}", e)))
})
})
}
pub fn get_header_values<'a>(
&'a self,
name: &'a HeaderName,
buf_size: usize,
) -> impl Iterator<Item = Result<HeaderValue, Error>> + 'a {
abi::MultiValueHostcall::new(
b'\0',
buf_size,
move |buf, buf_size, cursor, ending_cursor, nwritten| unsafe {
let name: &[u8] = name.as_ref();
abi::xqd_req_header_values_get(
self.handle(),
name.as_ptr(),
name.len(),
buf,
buf_size,
cursor,
ending_cursor,
nwritten,
)
},
)
.map(|res| {
res.map(|value_bytes| unsafe {
HeaderValue::from_maybe_shared_unchecked(value_bytes)
})
})
}
pub fn set_header_values<'a, I>(&mut self, name: &HeaderName, values: I) -> Result<(), Error>
where
I: IntoIterator<Item = &'a HeaderValue>,
{
let mut buf = vec![];
for value in values {
buf.put(value.as_bytes());
buf.put_u8(b'\0');
}
let name: &[u8] = name.as_ref();
let status = unsafe {
abi::xqd_req_header_values_set(
self.handle(),
name.as_ptr(),
name.len(),
buf.as_ptr(),
buf.len(),
)
};
if status.is_err() {
Err(Error::new("xqd_req_header_values_set failed"))
} else {
Ok(())
}
}
pub fn insert_header(&mut self, name: &HeaderName, value: &HeaderValue) -> Result<(), Error> {
let name_bytes: &[u8] = name.as_ref();
let value_bytes: &[u8] = value.as_ref();
let status = unsafe {
abi::xqd_req_header_insert(
self.handle(),
name_bytes.as_ptr(),
name_bytes.len(),
value_bytes.as_ptr(),
value_bytes.len(),
)
};
if status.is_err() {
Err(Error::new("xqd_req_header_insert returned error"))
} else {
Ok(())
}
}
pub fn append_header(&mut self, name: &HeaderName, value: &HeaderValue) -> Result<(), Error> {
let name_bytes: &[u8] = name.as_ref();
let value_bytes: &[u8] = value.as_ref();
let status = unsafe {
abi::xqd_req_header_append(
self.handle(),
name_bytes.as_ptr(),
name_bytes.len(),
value_bytes.as_ptr(),
value_bytes.len(),
)
};
if status.is_err() {
Err(Error::new("xqd_req_header_append returned error"))
} else {
Ok(())
}
}
pub fn get_version(&self) -> Result<Version, Error> {
let mut version = 0;
let status = unsafe { abi::xqd_req_version_get(self.handle(), &mut version) };
if status.is_err() {
Err(Error::new("xqd_req_version_get failed"))
} else {
abi::HttpVersion::try_from(version)
.map(Into::into)
.map_err(Error::new)
}
}
pub fn set_version(&mut self, v: Version) -> Result<(), Error> {
let status =
unsafe { abi::xqd_req_version_set(self.handle(), abi::HttpVersion::from(v) as u32) };
if status.is_err() {
Err(Error::new("xqd_req_version_get failed"))
} else {
Ok(())
}
}
pub fn get_method(&self, max_length: usize) -> Result<Method, Error> {
let mut method_bytes = Vec::with_capacity(max_length);
let mut nwritten = 0;
let status = unsafe {
abi::xqd_req_method_get(
self.handle(),
method_bytes.as_mut_ptr(),
method_bytes.capacity(),
&mut nwritten,
)
};
if status.is_err() {
return Err(Error::new("xqd_req_method_get failed"));
}
assert!(
nwritten <= method_bytes.capacity(),
"xqd_req_method_get wrote too many bytes"
);
unsafe {
method_bytes.set_len(nwritten);
}
Method::from_bytes(&method_bytes).map_err(|_| {
Error::new(format!(
"invalid method: {}",
String::from_utf8_lossy(&method_bytes)
))
})
}
pub fn set_method(&self, method: &Method) -> Result<(), Error> {
let method_bytes = method.as_str().as_bytes();
let status = unsafe {
abi::xqd_req_method_set(self.handle(), method_bytes.as_ptr(), method_bytes.len())
};
if status.is_err() {
Err(Error::new("xqd_req_method_set failed"))
} else {
Ok(())
}
}
pub fn get_uri(&self, max_length: usize) -> Result<Uri, Error> {
let mut uri_bytes = BytesMut::with_capacity(max_length);
let mut nwritten = 0;
let status = unsafe {
abi::xqd_req_uri_get(
self.handle(),
uri_bytes.as_mut_ptr(),
uri_bytes.capacity(),
&mut nwritten,
)
};
if status.is_err() {
return Err(Error::new("xqd_req_uri_get failed"));
}
assert!(
nwritten <= uri_bytes.capacity(),
"xqd_req_uri_get wrote too many bytes"
);
unsafe {
uri_bytes.set_len(nwritten);
}
Uri::from_maybe_shared(uri_bytes.freeze())
.map_err(|e| Error::new(format!("invalid URI: {}", e)))
}
pub fn set_uri(&mut self, uri: &Uri) -> Result<(), Error> {
let uri_bytes = uri.to_string().into_bytes();
let status =
unsafe { abi::xqd_req_uri_set(self.handle(), uri_bytes.as_ptr(), uri_bytes.len()) };
if status.is_err() {
Err(Error::new("xqd_req_uri_set failed"))
} else {
Ok(())
}
}
pub fn send(
self,
body: BodyHandle,
backend: &str,
ttl: i32,
) -> Result<(ResponseHandle, BodyHandle), Error> {
let mut resp_handle = ResponseHandle::INVALID;
let mut resp_body_handle = BodyHandle::INVALID;
let status = unsafe {
abi::xqd_req_send(
self.handle(),
body.handle(),
backend.as_ptr(),
backend.len(),
ttl,
&mut resp_handle,
&mut resp_body_handle,
)
};
if status.is_err() || resp_handle.is_invalid() || resp_body_handle.is_invalid() {
Err(Error::new("xqd_req_send failed"))
} else {
Ok((resp_handle, resp_body_handle))
}
}
}
pub fn downstream_request_and_body_handles() -> Result<(RequestHandle, BodyHandle), Error> {
let mut got_req = crate::request::GOT_DOWNSTREAM.lock().unwrap();
let mut got_body = crate::body::GOT_DOWNSTREAM.lock().unwrap();
if *got_req || *got_body {
return Err(Error::new(
"cannot get more than one handle to the downstream request per execution",
));
}
let mut req_handle = RequestHandle::INVALID;
let mut body_handle = BodyHandle::INVALID;
let status = unsafe { abi::xqd_req_body_downstream_get(&mut req_handle, &mut body_handle) };
if status.is_err() || req_handle.is_invalid() || body_handle.is_invalid() {
Err(Error::new("xqd_req_body_downstream_get failed"))
} else {
*got_req = true;
*got_body = true;
Ok((req_handle, body_handle))
}
}
/// Builds an [`http::Request`] describing the downstream (client) request.
///
/// The version, method (up to `crate::METHOD_MAX_LEN` bytes), URI (up to
/// `crate::URI_MAX_LEN` bytes) and every header value are copied from the host. The
/// body is the downstream body handle converted into a [`Body`].
///
/// # Errors
///
/// Returns an error if the handles were already obtained, if any hostcall or
/// conversion fails, or if the `http` builder rejects the assembled request.
pub fn downstream_request() -> Result<Request<Body>, Error> {
    let (req_handle, body_handle) = downstream_request_and_body_handles()?;

    let version = req_handle.get_version()?;
    let method = req_handle.get_method(crate::METHOD_MAX_LEN)?;
    let uri = req_handle.get_uri(crate::URI_MAX_LEN)?;
    let mut builder = Request::builder().version(version).method(method).uri(uri);

    for header_name in req_handle.get_header_names(crate::HEADER_NAME_MAX_LEN) {
        let header_name = header_name?;
        let values = req_handle.get_header_values(&header_name, crate::HEADER_VALUE_MAX_LEN);
        for value in values {
            builder = builder.header(&header_name, value?);
        }
    }

    Ok(builder.body(body_handle.into())?)
}
/// Extension methods for [`http::Request`] values used with this crate.
pub trait RequestExt {
    /// Sends the request to the backend named `backend` and returns its response.
    ///
    /// `ttl` is forwarded to the underlying send. NOTE(review): its units are not
    /// visible here.
    fn send(self, backend: &str, ttl: i32) -> Result<Response<Body>, Error>;
    /// Converts the request's body into a [`Body`], keeping the other request parts.
    fn inner_to_body(self) -> Result<Request<Body>, Error>;
    /// Converts the request's body into an owned byte vector, keeping the other
    /// request parts.
    fn inner_to_bytes(self) -> Result<Request<Vec<u8>>, Error>;
}
impl RequestExt for Request<Body> {
    /// Copies the version, method, URI and headers onto a new host request. Sends it
    /// with this request's body to `backend`, then rebuilds the host response
    /// (status, version, headers, body) as a `Response<Body>`.
    fn send(self, backend: &str, ttl: i32) -> Result<Response<Body>, Error> {
        let mut outgoing = RequestHandle::new()?;
        outgoing.set_version(self.version())?;
        outgoing.set_method(self.method())?;
        outgoing.set_uri(self.uri())?;

        // `keys()` yields each distinct name once; `get_all` supplies all its values.
        let headers = self.headers();
        for header_name in headers.keys() {
            outgoing.set_header_values(header_name, headers.get_all(header_name))?;
        }

        let body_handle = self.into_body().into_handle()?;
        let (resp_handle, resp_body_handle) = outgoing.send(body_handle, backend, ttl)?;

        let status = resp_handle.get_status()?;
        let version = resp_handle.get_version()?;
        let mut builder = Response::builder().status(status).version(version);
        for header_name in resp_handle.get_header_names(crate::HEADER_NAME_MAX_LEN) {
            let header_name = header_name?;
            let values = resp_handle.get_header_values(&header_name, crate::HEADER_VALUE_MAX_LEN);
            for value in values {
                builder = builder.header(&header_name, value?);
            }
        }
        Ok(builder.body(resp_body_handle.into())?)
    }

    /// The body is already a [`Body`], so the request is returned unchanged.
    fn inner_to_body(self) -> Result<Request<Body>, Error> {
        Ok(self)
    }

    /// Reads the [`Body`] into bytes and reattaches them to the original parts.
    fn inner_to_bytes(self) -> Result<Request<Vec<u8>>, Error> {
        let (head, body) = self.into_parts();
        let bytes = body.into_bytes()?;
        Ok(Request::from_parts(head, bytes))
    }
}
impl<T: AsRef<[u8]>> RequestExt for Request<T> {
fn send(self, backend: &str, ttl: i32) -> Result<Response<Body>, Error> {
self.inner_to_body()?.send(backend, ttl)
}
fn inner_to_body(self) -> Result<Request<Body>, Error> {
let mut body = Body::new()?;
body.write_all(self.body().as_ref())
.map_err(|e| Error::new(format!("{}", e)))?;
Ok(self.map(|_| body))
}
fn inner_to_bytes(self) -> Result<Request<Vec<u8>>, Error> {
Ok(self.map(|b| b.as_ref().to_vec()))
}
}