fastly 0.2.0-alpha1

Fastly Compute@Edge API
Documentation
use crate::abi::{self, XqdStatus};
use crate::body::{Body, BodyHandle};
use crate::Error;
use bytes::{BufMut, BytesMut};
use http::header::{HeaderName, HeaderValue};
use http::{Response, StatusCode, Version};
use lazy_static::lazy_static;
use std::convert::TryFrom;
use std::io::Write;
use std::sync::Mutex;

/// A handle to a response, represented as a raw `u32` that is passed to the `abi` hostcalls
/// (for example `abi::xqd_resp_new` fills one in through a `&mut ResponseHandle`).
///
/// `#[repr(transparent)]` gives this type the same layout as its single `u32` field,
/// presumably so it can cross the hostcall boundary directly. The type deliberately does not
/// derive `Clone`/`Copy`; within this file, the only way to duplicate a handle is the
/// crate-internal `unsafe fn handle`.
#[derive(Debug, Eq, Hash, PartialEq)]
#[repr(transparent)]
pub struct ResponseHandle {
    /// Raw handle value; `fastly_shared::INVALID_RESPONSE_HANDLE` marks an invalid handle.
    pub(crate) handle: u32,
}

impl ResponseHandle {
    pub const INVALID: Self = ResponseHandle {
        handle: fastly_shared::INVALID_RESPONSE_HANDLE,
    };

    pub fn is_invalid(&self) -> bool {
        self == &Self::INVALID
    }

    /// Get an owned `ResponseHandle` from a borrowed one.
    ///
    /// This should only be used when calling the raw ABI directly.
    pub(crate) unsafe fn handle(&self) -> Self {
        Self {
            handle: self.handle,
        }
    }

    pub fn new() -> Result<Self, Error> {
        let mut handle = ResponseHandle::INVALID;
        let status = unsafe { abi::xqd_resp_new(&mut handle) };
        if status.is_err() || handle.is_invalid() {
            Err(Error::new("xqd_resp_new failed"))
        } else {
            Ok(handle)
        }
    }

    pub fn get_header_names<'a>(
        &'a self,
        max_len: usize,
    ) -> impl Iterator<Item = Result<HeaderName, Error>> + 'a {
        abi::MultiValueHostcall::new(
            b'\0',
            max_len,
            move |buf, buf_size, cursor, ending_cursor, nwritten| unsafe {
                abi::xqd_resp_header_names_get(
                    self.handle(),
                    buf,
                    buf_size,
                    cursor,
                    ending_cursor,
                    nwritten,
                )
            },
        )
        .map(|res| {
            res.and_then(|name_bytes| {
                HeaderName::from_bytes(&name_bytes)
                    .map_err(|e| Error::new(format!("invalid header name: {}", e)))
            })
        })
    }

    pub fn get_header_values<'a>(
        &'a self,
        name: &'a HeaderName,
        max_len: usize,
    ) -> impl Iterator<Item = Result<HeaderValue, Error>> + 'a {
        abi::MultiValueHostcall::new(
            b'\0',
            max_len,
            move |buf, buf_size, cursor, ending_cursor, nwritten| unsafe {
                let name: &[u8] = name.as_ref();
                abi::xqd_resp_header_values_get(
                    self.handle(),
                    name.as_ptr(),
                    name.len(),
                    buf,
                    buf_size,
                    cursor,
                    ending_cursor,
                    nwritten,
                )
            },
        )
        .map(|res| {
            res.map(|value_bytes| unsafe {
                // we trust that the hostcall is giving us valid header bytes
                HeaderValue::from_maybe_shared_unchecked(value_bytes)
            })
        })
    }

    pub fn set_header_values<'a, I>(&mut self, name: &HeaderName, values: I) -> Result<(), Error>
    where
        I: IntoIterator<Item = &'a HeaderValue>,
    {
        // build a buffer of all the values, each terminated by a nul byte
        let mut buf = vec![];
        for value in values {
            buf.put(value.as_bytes());
            buf.put_u8(b'\0');
        }

        let name: &[u8] = name.as_ref();
        let status = unsafe {
            abi::xqd_resp_header_values_set(
                self.handle(),
                name.as_ptr(),
                name.len(),
                buf.as_ptr(),
                buf.len(),
            )
        };

        if status.is_err() {
            Err(Error::new("xqd_req_header_values_set failed"))
        } else {
            Ok(())
        }
    }

    pub fn get_header_value(
        &self,
        name: &HeaderName,
        max_len: usize,
    ) -> Result<HeaderValue, Error> {
        let name: &[u8] = name.as_ref();
        let mut buf = BytesMut::with_capacity(max_len);
        let mut nwritten = 0;
        let status = unsafe {
            abi::xqd_resp_header_value_get(
                self.handle(),
                name.as_ptr(),
                name.len(),
                buf.as_mut_ptr(),
                buf.capacity(),
                &mut nwritten,
            )
        };
        if status.is_err() {
            return Err(Error::new("xqd_resp_header_value_get returned error"));
        }
        assert!(nwritten <= buf.capacity(), "hostcall wrote too many bytes");
        unsafe {
            buf.set_len(nwritten);
        }
        Ok(HeaderValue::from_bytes(&buf).map_err(|_| Error::new("invalid header"))?)
    }

    pub fn insert_header(&mut self, name: &HeaderName, value: &HeaderValue) -> Result<(), Error> {
        let name_bytes: &[u8] = name.as_ref();
        let value_bytes: &[u8] = value.as_ref();
        let status = unsafe {
            abi::xqd_resp_header_insert(
                self.handle(),
                name_bytes.as_ptr(),
                name_bytes.len(),
                value_bytes.as_ptr(),
                value_bytes.len(),
            )
        };
        if status.is_err() {
            Err(Error::new("xqd_resp_header_insert returned error"))
        } else {
            Ok(())
        }
    }

    pub fn append_header(&mut self, name: &HeaderName, value: &HeaderValue) -> Result<(), Error> {
        let name_bytes: &[u8] = name.as_ref();
        let value_bytes: &[u8] = value.as_ref();
        let status = unsafe {
            abi::xqd_resp_header_append(
                self.handle(),
                name_bytes.as_ptr(),
                name_bytes.len(),
                value_bytes.as_ptr(),
                value_bytes.len(),
            )
        };
        if status.is_err() {
            Err(Error::new("xqd_resp_header_append returned error"))
        } else {
            Ok(())
        }
    }

    pub fn set_status(&mut self, status: StatusCode) {
        let status = unsafe { abi::xqd_resp_status_set(self.handle(), status.as_u16()) };
        assert_eq!(
            status,
            XqdStatus::OK,
            "setting a StatusCode should always succeed"
        );
    }

    pub fn get_status(&self) -> Result<StatusCode, Error> {
        let mut status = 0;
        let xqd_status = unsafe { abi::xqd_resp_status_get(self.handle(), &mut status) };
        if xqd_status.is_err() {
            Err(Error::new("xqd_resp_status_get failed"))
        } else {
            StatusCode::from_u16(status)
                .map_err(|e| Error::new(format!("invalid status code: {}", e)))
        }
    }

    pub fn get_version(&self) -> Result<Version, Error> {
        let mut version = 0;
        let status = unsafe { abi::xqd_resp_version_get(self.handle(), &mut version) };
        if status.is_err() {
            Err(Error::new("xqd_resp_version_get failed"))
        } else {
            abi::HttpVersion::try_from(version)
                .map(Into::into)
                .map_err(Error::new)
        }
    }

    pub fn set_version(&mut self, v: Version) -> Result<(), Error> {
        let status =
            unsafe { abi::xqd_resp_version_set(self.handle(), abi::HttpVersion::from(v) as u32) };
        if status.is_err() {
            Err(Error::new("xqd_req_version_get failed"))
        } else {
            Ok(())
        }
    }

    pub fn send_downstream(self, body: BodyHandle) -> Result<(), Error> {
        let status = unsafe { abi::xqd_resp_send_downstream(self, body) };
        if status.is_err() {
            Err(Error::new("xqd_resp_send_downstream failed"))
        } else {
            Ok(())
        }
    }
}

pub trait ResponseExt {
    /// Convert the body of this response into a [`Body`] holding the same bytes.
    ///
    /// Status, version, and headers are carried over unchanged.
    fn inner_to_body(self) -> Result<Response<Body>, Error>;

    /// Convert the body of this response into a `Vec<u8>` holding its remaining contents.
    ///
    /// This copies and buffers the whole body in memory, so it is meant as a convenience for
    /// small responses only.
    fn inner_to_bytes(self) -> Result<Response<Vec<u8>>, Error>;

    /// Send this response (status, version, headers, and body) to the downstream client.
    ///
    /// # Errors
    ///
    /// Fails if a response has already been sent downstream during this execution, or if any
    /// of the underlying hostcalls fails.
    fn send_downstream(self) -> Result<(), Error>;
}

impl ResponseExt for Response<Body> {
    fn send_downstream(self) -> Result<(), Error> {
        lazy_static! {
            static ref SENT: Mutex<bool> = Mutex::new(false);
        }

        let mut sent = SENT.lock().unwrap();
        if *sent {
            return Err(Error::new(
                "cannot send more than one downstream response per execution",
            ));
        }

        let (parts, body) = self.into_parts();

        let mut resp_handle = ResponseHandle::new()?;

        for name in parts.headers.keys() {
            resp_handle.set_header_values(name, parts.headers.get_all(name))?;
        }

        resp_handle.set_status(parts.status);

        resp_handle.set_version(parts.version)?;

        resp_handle.send_downstream(body.into_handle()?)?;

        *sent = true;
        Ok(())
    }

    fn inner_to_body(self) -> Result<Response<Body>, Error> {
        Ok(self)
    }

    fn inner_to_bytes(self) -> Result<Response<Vec<u8>>, Error> {
        let (parts, body) = self.into_parts();
        Ok(Response::from_parts(parts, body.into_bytes()?))
    }
}

impl<T: AsRef<[u8]>> ResponseExt for Response<T> {
    fn send_downstream(self) -> Result<(), Error> {
        let mut body = Body::new()?;
        body.write_all(self.body().as_ref())
            .map_err(|e| Error::new(format!("{}", e)))?;
        self.map(|_| body).send_downstream()
    }

    fn inner_to_body(self) -> Result<Response<Body>, Error> {
        let mut body = Body::new()?;

        body.write_all(self.body().as_ref())
            .map_err(|e| Error::new(format!("{}", e)))?;

        Ok(self.map(|_| body))
    }

    fn inner_to_bytes(self) -> Result<Response<Vec<u8>>, Error> {
        Ok(self.map(|b| b.as_ref().to_vec()))
    }
}