ript 0.1.2

Rust implementation of the InertiaJS protocol, compatible with `riptc` for generating strongly typed TypeScript bindings.
Documentation
//! Implements the logic for the Inertia extractor which is the entrypoint to inertia on the
//! backend and implements the public facing API of this crate for use with the inertia driver.

#[cfg(feature = "validator")]
mod form;

use std::{
    collections::VecDeque,
    convert::Infallible,
    ops::{ControlFlow, Deref},
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use axum::{
    Extension, RequestPartsExt,
    extract::FromRequestParts,
    http::{HeaderMap, Method, StatusCode},
    response::{IntoResponse, Response},
};
use bon::{Builder, builder};
use serde::Serialize;
use std::future::Future;

use crate::{
    extension::InertiaExtension,
    prop::{PropBuilder, PropControlFlow, prop_builder},
};

/// The entrypoint to `Inertia` that is used by route handlers and response mapping middleware.
///
/// This type intentionally does not implement `Clone`: every method that changes the shared
/// extension state consumes `self`, so at most one prop chain can be built per extractor.
#[derive(Builder)]
pub struct Inertia {
    /// HTTP method of the incoming request (used when deciding whether a request is a
    /// partial reload).
    method: Method,
    /// Headers of the incoming request; Inertia-specific values are read from these on
    /// demand via `RequestHeaders::from`.
    header_map: HeaderMap,
    /// Per-request Inertia state. The same `Arc` is stored in the request extensions and is
    /// attached to the response so later middleware can keep working with it.
    extension: Arc<InertiaExtension>,
}

impl<S> FromRequestParts<S> for Inertia
where
    S: Send + Sync,
{
    type Rejection = Infallible;

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let header_map: HeaderMap = parts.extract().await?;

        let extension = parts.extensions.get_or_insert_default();

        Ok(Self {
            method: parts.method.clone(),
            header_map,
            extension: Arc::clone(extension),
        })
    }
}

// A note to developers on the implementation...
//
// There are some weird tricks done here to make the following conditions true:
// 1. Users can pass async fns in as props that can also borrow from the route
// 2. Unsafe code is not required to allow this
// 3. Multiple prop render chains cannot be alive at the same time (this makes
// analysis possible by the compiler to basically walk down a flattened tree of the props)
//
// This makes code generation work well and the developer experience very good. To do this,
// `Inertia` deliberately does _NOT_ implement `Clone`: in the compiler we should be able
// to statically analyze branches and allow only a single instance of the extractor at a time.
//
// This is also why every single public method here that changes the extension state consumes `self`.
// Additionally this serves another side benefit, which is that we can freely modify the extractor
// in the prop chain future without worrying about two different modifications happening at once.
//
// What I mean by this is that in order to modify the extractor at all, you have to go through one of
// the self consuming methods, and the only way to "run" the extractor changes is to await the chain,
// which effectively consumes the entire thing. Please ensure that these contracts are upheld when
// changing anything here or else you may indirectly break the compiler, and cause breaking changes
// with how people's route handlers are implemented.
impl Inertia {
    /// Wraps `self` and a base response future into a fresh [`PropChain`].
    ///
    /// The base future ends up last in the chain (props are pushed in front of it) and must
    /// always resolve to `Some`, because that is what terminates the chain when awaited.
    fn into_chain<'route>(
        self,
        base: impl Future<Output = Option<Response>> + Send + 'route,
    ) -> PropChain<'route> {
        let base: Pin<Box<dyn Future<Output = Option<Response>> + Send + 'route>> =
            Box::pin(base);

        PropChain {
            inertia: self,
            props: VecDeque::from([base]),
        }
    }

    /// Render the current page with whatever props you have set.
    ///
    /// NOTE: this will not immediately render as you may have some share middleware that will still
    /// be able to morph the response structure.
    pub fn render<'route>(self, component: &'static str) -> PropChain<'route> {
        let extension = Arc::clone(&self.extension);

        // The component is recorded only once every queued prop has run without breaking.
        self.into_chain(async move {
            extension.set_component(component);
            Some(Extension(extension).into_response())
        })
    }

    /// Produce a response that shares all of your props.
    ///
    /// The extension is inserted into `res`'s extensions once all queued props have run.
    pub fn share<'route>(self, mut res: Response) -> PropChain<'route> {
        let extension = Arc::clone(&self.extension);

        self.into_chain(async move {
            res.extensions_mut().insert(extension);
            Some(res)
        })
    }

    /// Redirects you and returns you back to the referrer.
    /// Any props set automatically will become flashed.
    ///
    /// If the request has no (valid) `referer` header, the chain resolves to a
    /// `400 Bad Request` response instead.
    pub fn back<'route>(self) -> PropChain<'route> {
        let response = match RequestHeaders::from(&self.header_map).referer {
            Some(referer) => {
                self.extension.set_redirect(referer);
                Extension(Arc::clone(&self.extension)).into_response()
            }
            None => {
                tracing::error!("no referer found in request");
                (StatusCode::BAD_REQUEST, "No referer found in request").into_response()
            }
        };

        self.into_chain(async move { Some(response) })
    }

    /// Non-owning pointer to this extractor's extension, handed to prop futures so that
    /// chaining a prop does not need an `Arc` clone.
    fn extension_ptr(&self) -> InertiaExtensionPtr {
        InertiaExtensionPtr(&*self.extension)
    }

    /// Sets the version on the shared extension.
    pub fn set_version(&self, version: &'static str) {
        self.extension.set_version(version);
    }

    /// Provides the current version, used for defining cache busting
    /// in your versioning provider.
    ///
    /// This is the value of the request's `x-inertia-version` header, or `None` if it is
    /// missing or not a valid string.
    pub fn version(&self) -> Option<&str> {
        RequestHeaders::from(&self.header_map).x_inertia_version
    }

    /// Stores the `head` and `tail` HTML fragments on the shared extension (presumably
    /// wrapped around the page in non-JSON responses; see `InertiaExtension`).
    pub fn set_html_sandwich(&self, head: String, tail: String) {
        self.extension.set_html_sandwich(head, tail);
    }

    /// Set flash props on the request / response for the next request.
    ///
    /// Takes any flash props accumulated on the extension and passes them to `cb` so they can
    /// be persisted; `cb` is not called when there are none.
    pub fn set_flash_props_for_next_request(&self, cb: impl FnOnce(serde_json::Value)) {
        if let Some(flash_props) = self.extension.take_flash_props() {
            cb(flash_props);
        }
    }

    /// Apply flash props to the current request by extending the extension's props with them.
    pub fn apply_flash_props(&self, flash_props: serde_json::Value) {
        self.extension.extend_props(flash_props);
    }

    /// When using response middleware to modify the response, you must use this
    /// so that the correct continuation data goes through to allow inertia to work.
    ///
    /// Inserts this extractor's extension into `res`'s extensions and returns `res`.
    pub fn next_response(self, mut res: Response) -> Response {
        res.extensions_mut().insert(self.extension);
        res
    }
}

/// Inertia request headers
///
/// A borrowed, parsed view over the Inertia-specific headers of a request, built with
/// `RequestHeaders::from(&header_map)`. It is `Copy`, so it can be passed around freely, and
/// `Debug` so it can be logged when diagnosing protocol issues.
#[derive(Clone, Copy, Debug)]
pub struct RequestHeaders<'req> {
    /// If true, this indicates that this is a JSON request and
    /// we should respond with JSON and not an HTML body
    x_inertia: bool,
    /// The version used for cache busting
    x_inertia_version: Option<&'req str>,
    /// Partial props requested. This will be comma-delimited by inertia.
    x_inertia_partial_data: Option<&'req str>,
    /// Partial props to exclude. This will be comma-delimited by inertia.
    x_inertia_partial_except: Option<&'req str>,
    /// Will contain some value if this is a partial request. This is required
    /// in addition to `partial_data` and `partial_except` because this header
    /// contains the component that was requesting this partial reload. If there
    /// is a mismatch, and the partial was coming from a different component then
    /// the partial request is wrong and must be rejected.
    x_inertia_partial_component: Option<&'req str>,
    /// Referer used to facilitate `inertia.back`
    referer: Option<&'req str>,
}

impl<'req> From<&'req HeaderMap> for RequestHeaders<'req> {
    fn from(headers: &'req HeaderMap) -> Self {
        Self {
            x_inertia: headers.contains_key("x-inertia"),
            x_inertia_version: headers.get("x-inertia-version").and_then(|v| {
                v.to_str().map(Some).unwrap_or_else(|_| {
                    tracing::error!("x-inertia-version header is not a valid string");
                    None
                })
            }),
            x_inertia_partial_data: headers.get("x-inertia-partial-data").and_then(|v| {
                v.to_str().map(Some).unwrap_or_else(|_| {
                    tracing::error!("x-inertia-partial-data header is not a valid string");
                    None
                })
            }),
            x_inertia_partial_except: headers.get("x-inertia-partial-except").and_then(|v| {
                v.to_str().map(Some).unwrap_or_else(|_| {
                    tracing::error!("x-inertia-partial-except header is not a valid string");
                    None
                })
            }),
            x_inertia_partial_component: headers.get("x-inertia-partial-component").and_then(|v| {
                v.to_str().map(Some).unwrap_or_else(|_| {
                    tracing::error!("x-inertia-partial-component header is not a valid string");
                    None
                })
            }),
            referer: headers.get("referer").and_then(|v| {
                v.to_str().map(Some).unwrap_or_else(|_| {
                    tracing::error!("referer header is not a valid string");
                    None
                })
            }),
        }
    }
}

impl RequestHeaders<'_> {
    // TODO(@lazkindness): investigate how laravel checks this value when dealing
    // with post requests.
    /// Returns true if this request is a valid partial reload.
    ///
    /// A request is only ever partial when it carries the `x-inertia-partial-component`
    /// header; without it this is a regular visit, regardless of `expecting`.
    /// `GET` partial reloads are accepted whatever component they name (the page component
    /// may not be known yet when props are evaluated, e.g. before `render`'s base future has
    /// run). For any other method the header must equal `expecting`.
    pub fn partial(&self, method: &Method, expecting: Option<&'static str>) -> bool {
        match self.x_inertia_partial_component {
            // No partial-component header: never a partial reload, even if no component is
            // expected yet (comparing `None == None` must not make this partial).
            None => false,
            Some(_) if *method == Method::GET => true,
            Some(requested) => expecting == Some(requested),
        }
    }

    /// Returns an iterator over the comma-separated, trimmed values in the partial data
    /// header, or `None` if the header is absent.
    pub fn partial_data(&self) -> Option<impl Iterator<Item = &str>> {
        self.x_inertia_partial_data
            .map(|s| s.split(',').map(|s| s.trim()))
    }

    /// Returns an iterator over the comma-separated, trimmed values in the partial except
    /// header, or `None` if the header is absent.
    pub fn partial_except(&self) -> Option<impl Iterator<Item = &str>> {
        self.x_inertia_partial_except
            .map(|s| s.split(',').map(|s| s.trim()))
    }

    /// Returns true if the request is an inertia request,
    /// indicating that the response should be in JSON
    pub fn x_inertia(&self) -> bool {
        self.x_inertia
    }
}

/// Inertia response when rendering a page.
///
/// Serialized with camelCase keys (e.g. `encryptHistory`, `clearHistory`); serde emits the
/// fields in declaration order.
#[derive(Serialize, Builder)]
#[builder(builder_type(vis = "pub(crate)"))]
#[serde(rename_all = "camelCase")]
pub struct RenderResponse<'req> {
    /// The component to render
    component: &'static str,
    /// The props to pass to the component
    props: &'req serde_json::Value,
    /// The URL of the page. This is essentially what the router will redirect to,
    /// and is going to be the same URL that fired the request.
    ///
    /// Build it from the full request URI (`req.uri().to_string()`), not just the path:
    /// the path alone drops the query string.
    url: String,
    /// The version used for cache busting
    version: Option<&'static str>,
    /// True if the history should be encrypted
    encrypt_history: bool,
    /// True if the history should be cleared
    clear_history: bool,
}

/// A pending Inertia response: queued prop futures followed by one base response future.
///
/// Awaiting the chain polls the futures front to back until one resolves to `Some(response)`.
pub struct PropChain<'route> {
    /// Boxed futures, polled front to back. Props are pushed to the front, so the base
    /// future created by `Inertia::render` / `share` / `back` is always the last one, and it
    /// always resolves to `Some`.
    ///
    /// Declared *before* `inertia` on purpose: struct fields are dropped in declaration
    /// order, so every prop future (which may hold an `InertiaExtensionPtr` into
    /// `inertia.extension`) is dropped while the `Arc` it points into is still owned here.
    props: VecDeque<Pin<Box<dyn Future<Output = Option<Response>> + Send + 'route>>>,
    /// The extractor this chain was built from; owns the extension the props write into.
    inertia: Inertia,
}

impl<'route> PropChain<'route> {
    /// Pushes a new prop into the response you are currently building.
    ///
    /// A prop is any valid async closure that can optionally be wrapped with
    /// various attributes. See [`crate::prop::Prop`] for information on various
    /// wrappers.
    ///
    /// The prop's evaluation (`apply`) is deferred until the chain is awaited. Each new prop
    /// is queued in *front* of the ones added before it, so the most recently added prop is
    /// applied first. If applying it breaks with an error, that error's response becomes the
    /// response of the whole chain and the remaining futures are not polled.
    pub fn prop<F, Fut, P, S>(
        mut self,
        name: &'static str,
        prop: impl Into<PropBuilder<F, S>>,
    ) -> Self
    where
        F: FnOnce() -> Fut + Send + 'route,
        Fut: Future<Output = P> + Send + 'route,
        P: PropControlFlow,
        S: prop_builder::State,
    {
        let prop = prop.into().build();
        // Request context is captured now; the component is whatever the extension holds at
        // the time this prop is added.
        let eval = prop
            .eval(name)
            .req_headers(RequestHeaders::from(&self.inertia.header_map))
            .maybe_component(self.inertia.extension.component())
            .method(&self.inertia.method)
            .build();

        // Non-owning pointer instead of an `Arc` clone per prop; the chain owns the `Arc`.
        let extension = self.inertia.extension_ptr();

        self.props.push_front(Box::pin(async move {
            let extension = extension.deref();
            if let ControlFlow::Break(err) = eval.apply(name, extension).await {
                return Some(err.into_response());
            }

            None
        }));

        self
    }

    /// Adds `prop` under `name` to the flash props on the shared extension.
    ///
    /// # Panics
    ///
    /// Panics if `prop` cannot be serialized to a `serde_json::Value`.
    pub fn flash_prop<T>(self, name: &'static str, prop: T) -> Self
    where
        T: Serialize,
    {
        // The chain owns the `Arc`, so no raw-pointer indirection is needed here.
        let value = serde_json::to_value(&prop).expect("failed to serialize flash prop");
        self.inertia.extension.add_flash_prop(name, value);

        self
    }
}

/// Non-owning pointer to the `InertiaExtension` owned (through an `Arc`) by the `Inertia`
/// stored inside a `PropChain`.
///
/// Prop futures hold this instead of an `Arc` clone so that chaining a prop does not bump the
/// reference count. It is created only by `Inertia::extension_ptr`.
struct InertiaExtensionPtr(*const InertiaExtension);

impl Deref for InertiaExtensionPtr {
    type Target = InertiaExtension;

    fn deref(&self) -> &Self::Target {
        // SAFETY: the pointer targets the heap allocation of the `Arc<InertiaExtension>` held
        // by the `Inertia` that the owning `PropChain` stores. That allocation does not move
        // when the chain (or the extractor inside it) is moved, so pinning is not what keeps
        // it valid. What matters is that the `Arc` is alive whenever this is dereferenced:
        // 1. values of this type live only inside futures stored in that same chain's
        //    private `props`, or are used transiently while the chain is borrowed;
        // 2. those futures are only polled by awaiting the chain, and polling requires the
        //    chain, and therefore its `Inertia` and `Arc`, to be alive;
        // 3. prop chain and inertia cannot be cloned and expose no public fields, so a
        //    pointer can never escape to outlive its chain.
        unsafe { &*self.0 }
    }
}

// Compile-time guard for the `Send` impl below: moving this pointer to another thread is
// equivalent to moving a `&InertiaExtension`, which is only sound when the extension is `Sync`.
const _: fn() = || {
    fn assert_sync<T: ?Sized + Sync>() {}
    assert_sync::<InertiaExtension>();
};

// SAFETY: the pointer is only ever turned into a shared `&InertiaExtension` (see `Deref`),
// and `InertiaExtension: Sync` is asserted above, so sending it across threads is sound.
unsafe impl Send for InertiaExtensionPtr {}

impl<'route> Future for PropChain<'route> {
    type Output = Response;

    /// Drives the chain: polls the front future, discarding futures that finish with `None`,
    /// until one resolves to `Some(response)`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Every field of `PropChain` is `Unpin` (the prop futures are already pinned on the
        // heap via `Pin<Box<_>>`), so the safe projection is available. If a `!Unpin` field is
        // ever added, this stops compiling instead of silently becoming unsound.
        let this = self.get_mut();

        loop {
            // Chains are only built with a base future that always yields `Some`, so the
            // queue cannot run dry before a response is produced.
            let fut = this
                .props
                .front_mut()
                .expect("prop chain must end with a base future that resolves to `Some`");

            match fut.as_mut().poll(cx) {
                // the moment we get a response, end. this will be either a user response from an
                // error, or the "base" extension response.
                Poll::Ready(Some(resp)) => return Poll::Ready(resp),
                // completed without a response: discard it and immediately try the next one.
                Poll::Ready(None) => {
                    this.props.pop_front();
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}