use crate::abi::{self, XqdStatus};
use crate::body::{Body, BodyHandle};
use crate::Error;
use bytes::{BufMut, BytesMut};
use http::header::{HeaderName, HeaderValue};
use http::{Response, StatusCode, Version};
use lazy_static::lazy_static;
use std::convert::TryFrom;
use std::io::Write;
use std::sync::Mutex;
/// Opaque handle to a response created through the `xqd_resp_*` hostcalls.
///
/// `#[repr(transparent)]` guarantees the same layout as the wrapped `u32`. Presumably this is so
/// the handle can cross the hostcall ABI as a bare integer — TODO confirm against `abi`.
#[derive(Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct ResponseHandle {
    /// Raw handle value as understood by the host.
    pub(crate) handle: u32,
}
impl ResponseHandle {
pub const INVALID: Self = ResponseHandle {
handle: fastly_shared::INVALID_RESPONSE_HANDLE,
};
/// Returns `true` when this handle equals [`ResponseHandle::INVALID`].
pub fn is_invalid(&self) -> bool {
    *self == Self::INVALID
}
pub(crate) unsafe fn handle(&self) -> Self {
Self {
handle: self.handle,
}
}
pub fn new() -> Result<Self, Error> {
let mut handle = ResponseHandle::INVALID;
let status = unsafe { abi::xqd_resp_new(&mut handle) };
if status.is_err() || handle.is_invalid() {
Err(Error::new("xqd_resp_new failed"))
} else {
Ok(handle)
}
}
pub fn get_header_names<'a>(
&'a self,
max_len: usize,
) -> impl Iterator<Item = Result<HeaderName, Error>> + 'a {
abi::MultiValueHostcall::new(
b'\0',
max_len,
move |buf, buf_size, cursor, ending_cursor, nwritten| unsafe {
abi::xqd_resp_header_names_get(
self.handle(),
buf,
buf_size,
cursor,
ending_cursor,
nwritten,
)
},
)
.map(|res| {
res.and_then(|name_bytes| {
HeaderName::from_bytes(&name_bytes)
.map_err(|e| Error::new(format!("invalid header name: {}", e)))
})
})
}
pub fn get_header_values<'a>(
&'a self,
name: &'a HeaderName,
max_len: usize,
) -> impl Iterator<Item = Result<HeaderValue, Error>> + 'a {
abi::MultiValueHostcall::new(
b'\0',
max_len,
move |buf, buf_size, cursor, ending_cursor, nwritten| unsafe {
let name: &[u8] = name.as_ref();
abi::xqd_resp_header_values_get(
self.handle(),
name.as_ptr(),
name.len(),
buf,
buf_size,
cursor,
ending_cursor,
nwritten,
)
},
)
.map(|res| {
res.map(|value_bytes| unsafe {
HeaderValue::from_maybe_shared_unchecked(value_bytes)
})
})
}
pub fn set_header_values<'a, I>(&mut self, name: &HeaderName, values: I) -> Result<(), Error>
where
I: IntoIterator<Item = &'a HeaderValue>,
{
let mut buf = vec![];
for value in values {
buf.put(value.as_bytes());
buf.put_u8(b'\0');
}
let name: &[u8] = name.as_ref();
let status = unsafe {
abi::xqd_resp_header_values_set(
self.handle(),
name.as_ptr(),
name.len(),
buf.as_ptr(),
buf.len(),
)
};
if status.is_err() {
Err(Error::new("xqd_req_header_values_set failed"))
} else {
Ok(())
}
}
pub fn get_header_value(
&self,
name: &HeaderName,
max_len: usize,
) -> Result<HeaderValue, Error> {
let name: &[u8] = name.as_ref();
let mut buf = BytesMut::with_capacity(max_len);
let mut nwritten = 0;
let status = unsafe {
abi::xqd_resp_header_value_get(
self.handle(),
name.as_ptr(),
name.len(),
buf.as_mut_ptr(),
buf.capacity(),
&mut nwritten,
)
};
if status.is_err() {
return Err(Error::new("xqd_resp_header_value_get returned error"));
}
assert!(nwritten <= buf.capacity(), "hostcall wrote too many bytes");
unsafe {
buf.set_len(nwritten);
}
Ok(HeaderValue::from_bytes(&buf).map_err(|_| Error::new("invalid header"))?)
}
pub fn insert_header(&mut self, name: &HeaderName, value: &HeaderValue) -> Result<(), Error> {
let name_bytes: &[u8] = name.as_ref();
let value_bytes: &[u8] = value.as_ref();
let status = unsafe {
abi::xqd_resp_header_insert(
self.handle(),
name_bytes.as_ptr(),
name_bytes.len(),
value_bytes.as_ptr(),
value_bytes.len(),
)
};
if status.is_err() {
Err(Error::new("xqd_resp_header_insert returned error"))
} else {
Ok(())
}
}
pub fn append_header(&mut self, name: &HeaderName, value: &HeaderValue) -> Result<(), Error> {
let name_bytes: &[u8] = name.as_ref();
let value_bytes: &[u8] = value.as_ref();
let status = unsafe {
abi::xqd_resp_header_append(
self.handle(),
name_bytes.as_ptr(),
name_bytes.len(),
value_bytes.as_ptr(),
value_bytes.len(),
)
};
if status.is_err() {
Err(Error::new("xqd_resp_header_append returned error"))
} else {
Ok(())
}
}
/// Sets the response status code through the `xqd_resp_status_set` hostcall.
///
/// # Panics
///
/// Panics if the hostcall returns anything other than `XqdStatus::OK`. Any `StatusCode` is
/// expected to be accepted.
pub fn set_status(&mut self, status: StatusCode) {
    let code = status.as_u16();
    let xqd_status = unsafe { abi::xqd_resp_status_set(self.handle(), code) };
    assert_eq!(
        xqd_status,
        XqdStatus::OK,
        "setting a StatusCode should always succeed"
    );
}
pub fn get_status(&self) -> Result<StatusCode, Error> {
let mut status = 0;
let xqd_status = unsafe { abi::xqd_resp_status_get(self.handle(), &mut status) };
if xqd_status.is_err() {
Err(Error::new("xqd_resp_status_get failed"))
} else {
StatusCode::from_u16(status)
.map_err(|e| Error::new(format!("invalid status code: {}", e)))
}
}
pub fn get_version(&self) -> Result<Version, Error> {
let mut version = 0;
let status = unsafe { abi::xqd_resp_version_get(self.handle(), &mut version) };
if status.is_err() {
Err(Error::new("xqd_resp_version_get failed"))
} else {
abi::HttpVersion::try_from(version)
.map(Into::into)
.map_err(Error::new)
}
}
pub fn set_version(&mut self, v: Version) -> Result<(), Error> {
let status =
unsafe { abi::xqd_resp_version_set(self.handle(), abi::HttpVersion::from(v) as u32) };
if status.is_err() {
Err(Error::new("xqd_req_version_get failed"))
} else {
Ok(())
}
}
pub fn send_downstream(self, body: BodyHandle) -> Result<(), Error> {
let status = unsafe { abi::xqd_resp_send_downstream(self, body) };
if status.is_err() {
Err(Error::new("xqd_resp_send_downstream failed"))
} else {
Ok(())
}
}
}
/// Extension methods for `http::Response` values.
///
/// This module implements the trait for `Response<Body>` and for `Response<T>` where
/// `T: AsRef<[u8]>`.
pub trait ResponseExt {
    /// Converts the response into one whose body is a [`Body`].
    fn inner_to_body(self) -> Result<Response<Body>, Error>;

    /// Converts the response into one whose body is an owned `Vec<u8>`.
    fn inner_to_bytes(self) -> Result<Response<Vec<u8>>, Error>;

    /// Consumes the response and sends it downstream.
    ///
    /// # Errors
    ///
    /// Returns an error if conversion or sending fails. The implementations in this module also
    /// refuse a second send within the same execution.
    fn send_downstream(self) -> Result<(), Error>;
}
impl ResponseExt for Response<Body> {
    /// Copies this response's headers, status and version onto a fresh [`ResponseHandle`], then
    /// sends it with the body's handle.
    ///
    /// A process-wide flag behind a `Mutex` records a successful send. Later calls return an
    /// error. If any step fails, the flag is left unset.
    ///
    /// # Errors
    ///
    /// Returns an error if a response was already sent, or if any hostcall or body conversion
    /// fails.
    ///
    /// # Panics
    ///
    /// Panics if the flag's mutex is poisoned, or if `ResponseHandle::set_status` panics.
    fn send_downstream(self) -> Result<(), Error> {
        lazy_static! {
            static ref SENT: Mutex<bool> = Mutex::new(false);
        }
        // The guard is held for the whole call, so the check and the final store cannot
        // interleave with another caller.
        let mut already_sent = SENT.lock().unwrap();
        if *already_sent {
            return Err(Error::new(
                "cannot send more than one downstream response per execution",
            ));
        }
        let (head, body) = self.into_parts();
        let mut handle = ResponseHandle::new()?;
        // `HeaderMap::keys` yields each name once; `get_all` supplies every value for it.
        head.headers
            .keys()
            .try_for_each(|name| handle.set_header_values(name, head.headers.get_all(name)))?;
        handle.set_status(head.status);
        handle.set_version(head.version)?;
        let body_handle = body.into_handle()?;
        handle.send_downstream(body_handle)?;
        *already_sent = true;
        Ok(())
    }

    /// Already the right shape; returns the response unchanged.
    fn inner_to_body(self) -> Result<Response<Body>, Error> {
        Ok(self)
    }

    /// Reads the body out into a `Vec<u8>`, keeping the response head.
    fn inner_to_bytes(self) -> Result<Response<Vec<u8>>, Error> {
        let (head, body) = self.into_parts();
        let bytes = body.into_bytes()?;
        Ok(Response::from_parts(head, bytes))
    }
}
impl<T: AsRef<[u8]>> ResponseExt for Response<T> {
fn send_downstream(self) -> Result<(), Error> {
let mut body = Body::new()?;
body.write_all(self.body().as_ref())
.map_err(|e| Error::new(format!("{}", e)))?;
self.map(|_| body).send_downstream()
}
fn inner_to_body(self) -> Result<Response<Body>, Error> {
let mut body = Body::new()?;
body.write_all(self.body().as_ref())
.map_err(|e| Error::new(format!("{}", e)))?;
Ok(self.map(|_| body))
}
fn inner_to_bytes(self) -> Result<Response<Vec<u8>>, Error> {
Ok(self.map(|b| b.as_ref().to_vec()))
}
}