#[cfg(feature = "validator")]
mod form;
use std::{
collections::VecDeque,
convert::Infallible,
ops::{ControlFlow, Deref},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use axum::{
Extension, RequestPartsExt,
extract::FromRequestParts,
http::{HeaderMap, Method, StatusCode},
response::{IntoResponse, Response},
};
use bon::{Builder, builder};
use serde::Serialize;
use std::future::Future;
use crate::{
extension::InertiaExtension,
prop::{PropBuilder, PropControlFlow, prop_builder},
};
/// Per-request Inertia handle.
///
/// Obtained in handlers through its [`FromRequestParts`] implementation (or built with the
/// derived builder). It carries what is needed to assemble an Inertia response: the
/// request method, a copy of the request headers, and the request's shared
/// [`InertiaExtension`].
#[derive(Builder)]
pub struct Inertia {
    /// HTTP method of the incoming request.
    method: Method,
    /// Copy of the incoming request headers; parsed on demand into [`RequestHeaders`].
    header_map: HeaderMap,
    /// Extension shared with the request's extensions map and attached to the response.
    extension: Arc<InertiaExtension>,
}
impl<S> FromRequestParts<S> for Inertia
where
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
let header_map: HeaderMap = parts.extract().await?;
let extension = parts.extensions.get_or_insert_default();
Ok(Self {
method: parts.method.clone(),
header_map,
extension: Arc::clone(extension),
})
}
}
impl Inertia {
pub fn render<'route>(self, component: &'static str) -> PropChain<'route> {
let extension_for_response = Arc::clone(&self.extension);
let base_response_fut = Box::pin(async move {
extension_for_response.set_component(component);
Some(Extension(extension_for_response).into_response())
})
as Pin<Box<dyn Future<Output = Option<Response>> + Send + 'route>>;
PropChain {
inertia: self,
props: VecDeque::from_iter([base_response_fut]),
}
}
pub fn share<'route>(self, mut res: Response) -> PropChain<'route> {
let extension_for_response = Arc::clone(&self.extension);
let base_response_fut = Box::pin(async move {
res.extensions_mut().insert(extension_for_response);
Some(res)
})
as Pin<Box<dyn Future<Output = Option<Response>> + Send + 'route>>;
PropChain {
inertia: self,
props: VecDeque::from_iter([base_response_fut]),
}
}
pub fn back<'route>(self) -> PropChain<'route> {
let response = match RequestHeaders::from(&self.header_map).referer {
Some(referer) => {
self.extension.set_redirect(referer);
(Extension(Arc::clone(&self.extension))).into_response()
}
None => {
tracing::error!("no referer found in request");
(StatusCode::BAD_REQUEST, "No referer found in request").into_response()
}
};
let base_response_fut = Box::pin(async move { Some(response) })
as Pin<Box<dyn Future<Output = Option<Response>> + Send + 'route>>;
PropChain {
inertia: self,
props: VecDeque::from_iter([base_response_fut]),
}
}
fn extension_ptr(&self) -> InertiaExtensionPtr {
InertiaExtensionPtr(&*self.extension)
}
pub fn set_version(&self, version: &'static str) {
self.extension.set_version(version);
}
pub fn version(&self) -> Option<&str> {
RequestHeaders::from(&self.header_map).x_inertia_version
}
pub fn set_html_sandwich(&self, head: String, tail: String) {
self.extension.set_html_sandwich(head, tail);
}
pub fn set_flash_props_for_next_request(&self, cb: impl FnOnce(serde_json::Value)) {
if let Some(flash_props) = self.extension.take_flash_props() {
cb(flash_props);
};
}
pub fn apply_flash_props(&self, flash_props: serde_json::Value) {
self.extension.extend_props(flash_props);
}
pub fn next_response(self, mut res: Response) -> Response {
res.extensions_mut().insert(self.extension);
res
}
}
/// Borrowed view of the Inertia-related headers of a request.
///
/// Built with `RequestHeaders::from(&header_map)`. Every field borrows from that map.
#[derive(Debug, Clone, Copy)]
pub struct RequestHeaders<'req> {
    /// Whether an `X-Inertia` header is present; its value is not inspected.
    x_inertia: bool,
    /// Value of `X-Inertia-Version`, if any.
    x_inertia_version: Option<&'req str>,
    /// Raw comma-separated value of `X-Inertia-Partial-Data`, if any.
    x_inertia_partial_data: Option<&'req str>,
    /// Raw comma-separated value of `X-Inertia-Partial-Except`, if any.
    x_inertia_partial_except: Option<&'req str>,
    /// Value of `X-Inertia-Partial-Component`, if any.
    x_inertia_partial_component: Option<&'req str>,
    /// Value of the `Referer` header, if any.
    referer: Option<&'req str>,
}
impl<'req> From<&'req HeaderMap> for RequestHeaders<'req> {
fn from(headers: &'req HeaderMap) -> Self {
Self {
x_inertia: headers.contains_key("x-inertia"),
x_inertia_version: headers.get("x-inertia-version").and_then(|v| {
v.to_str().map(Some).unwrap_or_else(|_| {
tracing::error!("x-inertia-version header is not a valid string");
None
})
}),
x_inertia_partial_data: headers.get("x-inertia-partial-data").and_then(|v| {
v.to_str().map(Some).unwrap_or_else(|_| {
tracing::error!("x-inertia-partial-data header is not a valid string");
None
})
}),
x_inertia_partial_except: headers.get("x-inertia-partial-except").and_then(|v| {
v.to_str().map(Some).unwrap_or_else(|_| {
tracing::error!("x-inertia-partial-except header is not a valid string");
None
})
}),
x_inertia_partial_component: headers.get("x-inertia-partial-component").and_then(|v| {
v.to_str().map(Some).unwrap_or_else(|_| {
tracing::error!("x-inertia-partial-component header is not a valid string");
None
})
}),
referer: headers.get("referer").and_then(|v| {
v.to_str().map(Some).unwrap_or_else(|_| {
tracing::error!("referer header is not a valid string");
None
})
}),
}
}
}
impl RequestHeaders<'_> {
    /// Returns whether the request should be handled as an Inertia partial reload.
    ///
    /// An `X-Inertia-Partial-Component` header is required. When it is present, the request
    /// counts as partial if the method is `GET`, or if the header value equals `expecting`.
    pub fn partial(&self, method: &Method, expecting: Option<&'static str>) -> bool {
        // Guard first. The previous single expression `is_some() && GET || header ==
        // expecting` compared `None == None` when neither the header nor an expected
        // component was present, and so reported ordinary visits as partial.
        let Some(requested) = self.x_inertia_partial_component else {
            return false;
        };
        method == Method::GET || Some(requested) == expecting
    }

    /// Entries of `X-Inertia-Partial-Data`, split on `,` and trimmed.
    ///
    /// `None` when the header is absent.
    pub fn partial_data(&self) -> Option<impl Iterator<Item = &str>> {
        self.x_inertia_partial_data
            .map(|list| list.split(',').map(str::trim))
    }

    /// Entries of `X-Inertia-Partial-Except`, split on `,` and trimmed.
    ///
    /// `None` when the header is absent.
    pub fn partial_except(&self) -> Option<impl Iterator<Item = &str>> {
        self.x_inertia_partial_except
            .map(|list| list.split(',').map(str::trim))
    }

    /// Whether the request carries an `X-Inertia` header.
    pub fn x_inertia(&self) -> bool {
        self.x_inertia
    }
}
/// Inertia page object, serialized to JSON with camelCase keys: `component`, `props`,
/// `url`, `version`, `encryptHistory`, `clearHistory`.
///
/// The generated builder type is declared `pub(crate)`.
#[derive(Serialize, Builder)]
#[builder(builder_type(vis = "pub(crate)"))]
#[serde(rename_all = "camelCase")]
pub struct RenderResponse<'req> {
    /// Name of the page component.
    component: &'static str,
    /// Page props, borrowed for the lifetime of the request.
    props: &'req serde_json::Value,
    /// Page URL.
    url: String,
    /// Asset version string, if any (serialized as `null` otherwise).
    version: Option<&'static str>,
    /// Serialized as `encryptHistory`.
    encrypt_history: bool,
    /// Serialized as `clearHistory`.
    clear_history: bool,
}
/// A pending Inertia response: a queue of futures that are driven front to back until one
/// of them yields a [`Response`].
///
/// `Inertia::render`, `Inertia::share`, and `Inertia::back` create the chain with a single
/// base future. [`PropChain::prop`] pushes each new prop to the front, so every prop runs
/// before that base future.
pub struct PropChain<'route> {
    // Declared BEFORE `inertia` on purpose. Struct fields are dropped in declaration
    // order, and the queued prop futures hold raw pointers into `inertia.extension`.
    // Dropping the futures first guarantees the pointee is still alive while they are
    // torn down.
    props: VecDeque<Pin<Box<dyn Future<Output = Option<Response>> + Send + 'route>>>,
    inertia: Inertia,
}
impl<'route> PropChain<'route> {
pub fn prop<F, Fut, P, S>(
mut self,
name: &'static str,
prop: impl Into<PropBuilder<F, S>>,
) -> Self
where
F: FnOnce() -> Fut + Send + 'route,
Fut: Future<Output = P> + Send + 'route,
P: PropControlFlow,
S: prop_builder::State,
{
let prop = prop.into().build();
let eval = prop
.eval(name)
.req_headers(RequestHeaders::from(&self.inertia.header_map))
.maybe_component(self.inertia.extension.component())
.method(&self.inertia.method)
.build();
let extension = self.inertia.extension_ptr();
self.props.push_front(Box::pin(async move {
let extension = extension.deref();
if let ControlFlow::Break(err) = eval.apply(name, extension).await {
return Some(err.into_response());
}
None
}));
self
}
pub fn flash_prop<T>(self, name: &'static str, prop: T) -> Self
where
T: Serialize,
{
let extension = self.inertia.extension_ptr();
extension.add_flash_prop(
name,
serde_json::to_value(&prop).expect("failed to serialize flash prop"),
);
self
}
}
/// Raw, non-owning pointer to the [`InertiaExtension`] held (via `Arc`) by an [`Inertia`].
///
/// It lets the boxed prop futures queued in a [`PropChain`] reach the extension without
/// holding a borrow of the chain that stores them.
struct InertiaExtensionPtr(*const InertiaExtension);

impl Deref for InertiaExtensionPtr {
    type Target = InertiaExtension;

    fn deref(&self) -> &Self::Target {
        // SAFETY: within this file the pointer comes from `Inertia::extension_ptr`, i.e.
        // from the `Arc` owned by the `Inertia` inside a `PropChain`. It is captured only by
        // futures stored in that same chain's queue. The `Arc` allocation does not move when
        // the chain moves, and the futures cannot outlive the chain that owns them.
        unsafe { &*self.0 }
    }
}

// SAFETY: moving this pointer to another thread only lets that thread obtain a shared
// `&InertiaExtension`, which is sound because `InertiaExtension: Sync` (asserted below).
unsafe impl Send for InertiaExtensionPtr {}

// Compile-time guard for the `Send` impl above: the build fails if `InertiaExtension`
// ever stops being `Sync`.
const _: () = {
    const fn assert_sync<T: ?Sized + Sync>() {}
    assert_sync::<InertiaExtension>();
};
impl<'route> Future for PropChain<'route> {
    type Output = Response;

    /// Drives the queued futures front to back.
    ///
    /// A future that completes with `None` is removed and the next one is polled. The first
    /// one to complete with `Some(response)` ends the chain with that response.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Every field of `PropChain` is `Unpin` (boxed futures, `VecDeque`, `Arc`,
        // `HeaderMap`, `Method`), so a safe `get_mut` replaces `get_unchecked_mut`.
        let this = self.get_mut();
        loop {
            // The base future always yields `Some`, so the queue is never exhausted by
            // this loop. Use a checked access instead of `unwrap_unchecked`, which would be
            // UB if that invariant were ever broken.
            let fut = this
                .props
                .front_mut()
                .expect("PropChain polled with an empty queue (polled after completion?)");
            match fut.as_mut().poll(cx) {
                Poll::Ready(Some(resp)) => return Poll::Ready(resp),
                Poll::Ready(None) => {
                    this.props.pop_front();
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}