use crate::*;
/// Extension trait for wrapping the full-page (non-htmx) responses of routes.
///
/// `F` is the wrapper. For the `Router` implementation in this file it is a
/// `Fn(Markup) -> Fut` that receives a handler's HTML body and produces the
/// final response.
pub trait HtmxRouting<F> {
/// Installs `wrapper` so that successful responses to requests lacking the
/// htmx request header are passed through it.
///
/// Consumes `self` and returns the modified value, so calls can be chained.
fn wrap_non_htmx(self, wrapper: F) -> Self;
}
impl<F, Fut, R> HtmxRouting<F> for Router
where
F: Fn(Markup) -> Fut + Clone + Send + 'static,
Fut: Future<Output = R> + Send,
R: IntoResponse,
{
fn wrap_non_htmx(self, wrapper: F) -> Self {
self.route_layer(HtmxLayer::wrap(wrapper))
}
}
/// Tower layer that wraps an inner service in `HtmxMiddleware`, giving each
/// wrapped service its own copy of the page wrapper `F`.
#[doc(hidden)]
#[derive(Clone)]
pub struct HtmxLayer<F> {
    /// Callback that turns a non-htmx response body into the final response.
    pub wrapper: F,
}

impl<F> HtmxLayer<F> {
    /// Creates a layer that stores `wrapper` for the services it builds.
    pub fn wrap(wrapper: F) -> Self {
        HtmxLayer { wrapper }
    }
}
/// Wraps `inner` in an `HtmxMiddleware` that holds its own clone of this
/// layer's wrapper.
impl<S, F: Clone> tower::Layer<S> for HtmxLayer<F> {
    type Service = HtmxMiddleware<S, F>;

    fn layer(&self, inner: S) -> Self::Service {
        let wrapper = F::clone(&self.wrapper);
        HtmxMiddleware { inner, wrapper }
    }
}
/// Tower service built by `HtmxLayer`. It forwards each request to `inner` and
/// post-processes the response, using `wrapper` to build the full response for
/// requests that lack the `axum_htmx::HX_REQUEST` header.
#[doc(hidden)]
#[derive(Clone)]
pub struct HtmxMiddleware<S, F> {
/// Page wrapper applied to successful non-htmx responses.
wrapper: F,
/// The inner service that actually handles the request.
inner: S,
}
use core::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use std::boxed::Box;
/// `tower::Service` implementation that post-processes the inner service's
/// response depending on whether the request came from htmx.
///
/// Only `200 OK` responses are modified; every other status is passed through
/// unchanged. For `200 OK` responses `Content-Type: text/html` is set, then:
///
/// * requests **without** the `axum_htmx::HX_REQUEST` header have their body
///   buffered (up to `MAX_WRAPPED_BODY_BYTES`), decoded as UTF-8 and handed to
///   `wrapper` as pre-escaped markup. The wrapper's output becomes the whole
///   response. If the body is too large, fails to stream, or is not valid UTF-8,
///   a `500 Internal Server Error` response is returned instead of panicking.
/// * requests **with** that header get a no-cache `Cache-Control` header and
///   keep their original body.
impl<S, F, Fut, R> tower::Service<Request<Body>> for HtmxMiddleware<S, F>
where
    S: tower::Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
    F: Fn(Markup) -> Fut + Send + Clone + 'static,
    Fut: Future<Output = R> + Send,
    R: IntoResponse,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<
        Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send + 'static>,
    >;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request<Body>) -> Self::Future {
        use axum::http::StatusCode;

        // Largest response body (in bytes) that will be buffered for wrapping.
        const MAX_WRAPPED_BODY_BYTES: usize = 10_000_000;

        // Inspect the headers before `request` is moved into the inner service.
        let not_htmx_request = !request.headers().contains_key(axum_htmx::HX_REQUEST);
        let future = self.inner.call(request);
        let wrapper = self.wrapper.clone();
        Box::pin(async move {
            let (mut parts, body) = future.await?.into_parts();
            if parts.status != StatusCode::OK {
                return Ok(Response::from_parts(parts, body));
            }
            parts
                .headers
                .insert(header::CONTENT_TYPE, HeaderValue::from_static("text/html"));
            if not_htmx_request {
                // NOTE(review): the wrapper builds a brand-new response, so any
                // headers on `parts` (including ones the handler set, e.g.
                // Set-Cookie) are dropped on this path — confirm that is intended.
                let bytes = match axum::body::to_bytes(body, MAX_WRAPPED_BODY_BYTES).await {
                    Ok(bytes) => bytes,
                    // Body exceeded the limit or failed mid-stream: answer 500
                    // rather than panicking inside the connection task.
                    Err(_) => return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
                };
                let content = match std::string::String::from_utf8(bytes.to_vec()) {
                    Ok(content) => content,
                    // Non-UTF-8 bytes cannot be embedded as HTML markup.
                    Err(_) => return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
                };
                Ok(wrapper(PreEscaped(content)).await.into_response())
            } else {
                parts.headers.insert(
                    header::CACHE_CONTROL,
                    HeaderValue::from_static(
                        "max-age=0, no-cache, must-revalidate, proxy-revalidate",
                    ),
                );
                Ok(Response::from_parts(parts, body))
            }
        })
    }
}