#[cfg(feature = "validator")]
mod form;
use std::{
convert::Infallible,
ops::{ControlFlow, Deref},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use axum::{
Extension, RequestPartsExt,
extract::FromRequestParts,
http::{HeaderMap, Method, StatusCode},
response::{IntoResponse, Response},
};
use bon::{Builder, builder};
use pin_project_lite::pin_project;
use serde::Serialize;
use std::future::Future;
use crate::{
extension::InertiaExtension,
prop::{PropBuilder, PropControlFlow, prop_builder},
};
/// Request-scoped Inertia handle, obtained in handlers through its
/// `FromRequestParts` implementation.
///
/// Holds the request method, a copy of the request headers, and the
/// `Arc<InertiaExtension>` that is shared through the request's extensions.
#[derive(Builder)]
pub struct Inertia {
// HTTP method of the incoming request (consulted when deciding partial reloads).
method: Method,
// Copy of the incoming request headers.
header_map: HeaderMap,
// State shared with the request extensions and attached to responses.
extension: Arc<InertiaExtension>,
}
impl<S> FromRequestParts<S> for Inertia
where
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
let header_map: HeaderMap = parts.extract().await?;
let extension = parts.extensions.get_or_insert_default();
Ok(Self {
method: parts.method.clone(),
header_map,
extension: Arc::clone(extension),
})
}
}
impl Inertia {
    /// Starts a prop chain that will render the page component `component`.
    ///
    /// The component name is written to the shared extension only when this
    /// innermost future is finally polled. Props chained on with
    /// [`PropChain::prop`] are applied before it. The resulting response
    /// carries the `Arc<InertiaExtension>` in its extensions.
    pub fn render(self, component: &'static str) -> PropChain<impl Future<Output = Response>> {
        let shared = Arc::clone(&self.extension);
        let props = async move {
            shared.set_component(component);
            Extension(shared).into_response()
        };
        PropChain {
            inertia: self,
            props,
        }
    }

    /// Starts a prop chain around an already-built response `res`.
    ///
    /// When awaited, the chain inserts the shared `Arc<InertiaExtension>` into
    /// `res`'s extensions and yields `res`.
    pub fn share(self, mut res: Response) -> PropChain<impl Future<Output = Response>> {
        let shared = Arc::clone(&self.extension);
        let props = async move {
            res.extensions_mut().insert(shared);
            res
        };
        PropChain {
            inertia: self,
            props,
        }
    }

    /// Builds a "redirect back" response from the request's `Referer` header.
    ///
    /// If a referer is present, it is recorded as the redirect target on the
    /// extension, and the response carries the extension. If the referer is
    /// missing or not a valid header string, an error is logged and
    /// `400 Bad Request` is returned. The returned chain accepts
    /// [`PropChain::flash_prop`] calls.
    pub fn back(self) -> PropChain<FlashedResponse> {
        let referer = RequestHeaders::from(&self.header_map).referer;
        let response = if let Some(target) = referer {
            self.extension.set_redirect(target);
            Extension(Arc::clone(&self.extension)).into_response()
        } else {
            tracing::error!("no referer found in request");
            (StatusCode::BAD_REQUEST, "No referer found in request").into_response()
        };
        PropChain {
            inertia: self,
            props: FlashedResponse(response),
        }
    }

    /// Returns a non-owning raw pointer to the shared extension.
    ///
    /// The pointer is only valid while `self.extension` (or another clone of
    /// that `Arc`) is alive.
    fn extension_ptr(&self) -> InertiaExtensionPtr {
        InertiaExtensionPtr(Arc::as_ptr(&self.extension))
    }

    /// Stores `version` on the shared extension.
    pub fn set_version(&self, version: &'static str) {
        self.extension.set_version(version);
    }

    /// Returns the request's `X-Inertia-Version` header, if present and valid.
    ///
    /// This reads the incoming header. It does not return the value passed to
    /// [`Inertia::set_version`].
    pub fn version(&self) -> Option<&str> {
        let headers = RequestHeaders::from(&self.header_map);
        headers.x_inertia_version
    }

    /// Stores the `head` and `tail` HTML fragments on the shared extension.
    pub fn set_html_sandwich(&self, head: String, tail: String) {
        self.extension.set_html_sandwich(head, tail);
    }

    /// Hands any flash props obtained from the extension (via
    /// `take_flash_props`) to `cb`. Does nothing when there are none.
    pub fn set_flash_props_for_next_request(&self, cb: impl FnOnce(serde_json::Value)) {
        let Some(flash_props) = self.extension.take_flash_props() else {
            return;
        };
        cb(flash_props);
    }

    /// Merges `flash_props` into this request's props on the shared extension.
    pub fn apply_flash_props(&self, flash_props: serde_json::Value) {
        self.extension.extend_props(flash_props);
    }

    /// Attaches the shared `Arc<InertiaExtension>` to `res` and returns it.
    pub fn next_response(self, res: Response) -> Response {
        let mut res = res;
        let Self { extension, .. } = self;
        res.extensions_mut().insert(extension);
        res
    }
}
/// Borrowed view of the Inertia-related headers of a request.
///
/// Built from a `HeaderMap`. String-valued headers are `None` when missing or
/// when their value is not a valid header string.
#[derive(Clone, Copy)]
pub struct RequestHeaders<'req> {
// Whether an `x-inertia` header is present; its value is not inspected.
x_inertia: bool,
x_inertia_version: Option<&'req str>,
// Raw comma-separated lists; split by `partial_data` / `partial_except`.
x_inertia_partial_data: Option<&'req str>,
x_inertia_partial_except: Option<&'req str>,
x_inertia_partial_component: Option<&'req str>,
// Read by `Inertia::back` as the redirect target.
referer: Option<&'req str>,
}
impl<'req> From<&'req HeaderMap> for RequestHeaders<'req> {
fn from(headers: &'req HeaderMap) -> Self {
Self {
x_inertia: headers.contains_key("x-inertia"),
x_inertia_version: headers.get("x-inertia-version").and_then(|v| {
v.to_str().map(Some).unwrap_or_else(|_| {
tracing::error!("x-inertia-version header is not a valid string");
None
})
}),
x_inertia_partial_data: headers.get("x-inertia-partial-data").and_then(|v| {
v.to_str().map(Some).unwrap_or_else(|_| {
tracing::error!("x-inertia-partial-data header is not a valid string");
None
})
}),
x_inertia_partial_except: headers.get("x-inertia-partial-except").and_then(|v| {
v.to_str().map(Some).unwrap_or_else(|_| {
tracing::error!("x-inertia-partial-except header is not a valid string");
None
})
}),
x_inertia_partial_component: headers.get("x-inertia-partial-component").and_then(|v| {
v.to_str().map(Some).unwrap_or_else(|_| {
tracing::error!("x-inertia-partial-component header is not a valid string");
None
})
}),
referer: headers.get("referer").and_then(|v| {
v.to_str().map(Some).unwrap_or_else(|_| {
tracing::error!("referer header is not a valid string");
None
})
}),
}
}
}
impl RequestHeaders<'_> {
    /// Returns `true` when this request should be treated as a partial reload.
    ///
    /// A request is only partial if it carries the `X-Inertia-Partial-Component`
    /// header. Given that header, it is partial when the method is `GET` or
    /// when the header names the `expecting` component.
    pub fn partial(&self, method: &Method, expecting: Option<&'static str>) -> bool {
        // Without the header this is an ordinary visit. The previous
        // `a && b || header == expecting` form evaluated `None == None` to
        // `true`, so header-less requests looked partial whenever `expecting`
        // was `None`.
        let Some(component) = self.x_inertia_partial_component else {
            return false;
        };
        method == Method::GET || expecting == Some(component)
    }

    /// Comma-separated prop names from `X-Inertia-Partial-Data`, each trimmed
    /// of surrounding whitespace; `None` when the header is absent or invalid.
    pub fn partial_data(&self) -> Option<impl Iterator<Item = &str>> {
        self.x_inertia_partial_data
            .map(|s| s.split(',').map(|s| s.trim()))
    }

    /// Comma-separated prop names from `X-Inertia-Partial-Except`, each trimmed
    /// of surrounding whitespace; `None` when the header is absent or invalid.
    pub fn partial_except(&self) -> Option<impl Iterator<Item = &str>> {
        self.x_inertia_partial_except
            .map(|s| s.split(',').map(|s| s.trim()))
    }

    /// Whether the request carried an `x-inertia` header.
    pub fn x_inertia(&self) -> bool {
        self.x_inertia
    }
}
/// Page payload serialized with camelCase keys: `component`, `props`, `url`,
/// `version`, `encryptHistory`, `clearHistory`.
///
/// NOTE(review): presumably this is the Inertia page object sent to the client.
/// Its builder is crate-private.
#[derive(Serialize, Builder)]
#[builder(builder_type(vis = "pub(crate)"))]
#[serde(rename_all = "camelCase")]
pub struct RenderResponse<'req> {
component: &'static str,
// Borrowed so the props value is serialized without being cloned.
props: &'req serde_json::Value,
url: String,
version: Option<&'static str>,
encrypt_history: bool,
clear_history: bool,
}
pin_project! {
/// A response under construction: the request-scoped [`Inertia`] handle plus
/// the value (a future, or a ready [`FlashedResponse`]) that produces the final
/// [`Response`].
///
/// `PropChain::prop` wraps `props` in a new future. Polling the chain as a
/// `Future` drives only `props`.
pub struct PropChain<F> {
// Request handle that the chaining methods read headers, method and extension from.
inertia: Inertia,
// Accumulated future (or the flashed response); structurally pinned.
#[pin]
props: F,
}
}
/// Response built by `Inertia::back`. Wrapping it lets a
/// `PropChain<FlashedResponse>` accept `flash_prop` calls and be returned
/// directly via `IntoResponse`.
pub struct FlashedResponse(Response);
impl PropChain<FlashedResponse> {
    /// Serializes `prop` and records it on the shared extension as the flash
    /// prop `name` (via `add_flash_prop`), returning the chain for more calls.
    ///
    /// # Panics
    ///
    /// Panics if `prop` cannot be converted into a `serde_json::Value`.
    pub fn flash_prop<T>(self, name: &'static str, prop: T) -> Self
    where
        T: Serialize,
    {
        let value = serde_json::to_value(&prop).expect("failed to serialize flash prop");
        // The pointer is dereferenced immediately, while `self.inertia` keeps
        // the owning `Arc` alive.
        self.inertia.extension_ptr().add_flash_prop(name, value);
        self
    }
}
impl IntoResponse for PropChain<FlashedResponse> {
fn into_response(self) -> Response {
self.props.0
}
}
impl<Fut> PropChain<Fut> {
pub fn prop<F, P, S>(
self,
name: &'static str,
prop: impl Into<PropBuilder<F, S>>,
) -> PropChain<impl Future<Output = Response>>
where
Fut: Future<Output = Response>,
F: AsyncFnOnce() -> P,
P: PropControlFlow,
S: prop_builder::State,
{
let prop = prop.into().build();
let eval = prop
.eval(name)
.req_headers(RequestHeaders::from(&self.inertia.header_map))
.maybe_component(self.inertia.extension.component())
.method(&self.inertia.method)
.build();
let extension = self.inertia.extension_ptr();
let props = async move {
let extension = extension.deref();
if let ControlFlow::Break(err) = eval.apply(name, extension).await {
return err.into_response();
}
self.props.await
};
PropChain {
inertia: self.inertia,
props,
}
}
}
/// Raw, non-owning pointer to an [`InertiaExtension`].
///
/// It is created from an `Arc<InertiaExtension>` and points into that `Arc`'s
/// allocation. It carries no lifetime, so the compiler does not check that the
/// `Arc` outlives it. Every holder must keep some clone of that `Arc` alive
/// while the pointer is dereferenced.
struct InertiaExtensionPtr(*const InertiaExtension);

impl Deref for InertiaExtensionPtr {
    type Target = InertiaExtension;

    fn deref(&self) -> &Self::Target {
        // SAFETY: the pointer comes from a live `Arc<InertiaExtension>`, so it
        // is non-null, aligned and initialised. Holders keep that `Arc` alive
        // while dereferencing (see the type-level docs).
        unsafe { &*self.0 }
    }
}

// Sending this pointer to another thread amounts to sharing an
// `&InertiaExtension`, which is only sound if `InertiaExtension: Sync`.
// Check that at compile time instead of leaving it implicit.
const _: fn() = || {
    fn assert_sync<T: ?Sized + Sync>() {}
    assert_sync::<InertiaExtension>();
};

// SAFETY: the pointee is `Sync` (asserted above), so using the pointer from
// another thread is as sound as sending `&InertiaExtension`. Liveness is the
// holder's obligation, as documented on the type.
unsafe impl Send for InertiaExtensionPtr {}
impl<F> Future for PropChain<F>
where
    F: Future<Output = Response>,
{
    type Output = Response;

    /// Drives the accumulated `props` future. The `inertia` handle is not
    /// pinned and is never touched here.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.project().props.poll(cx)
    }
}