use axum::{
Router,
body::Body,
http::{HeaderMap, HeaderValue, Method, Request, StatusCode, header},
response::{IntoResponse, Response},
};
use chrono::TimeDelta;
use std::{
collections::{BTreeMap, HashMap},
convert::Infallible,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Instant,
};
use tower::Service;
use tracing::trace;
use crate::{
ClientInfo, RenderContext, Rendered, RenderedBody,
servable::{Servable, ServableWithRoute},
};
struct Default404 {}
impl Servable for Default404 {
fn head<'a>(
&'a self,
_ctx: &'a RenderContext,
) -> Pin<Box<dyn Future<Output = Rendered<()>> + 'a + Send + Sync>> {
Box::pin(async {
return Rendered {
code: StatusCode::NOT_FOUND,
body: (),
ttl: Some(TimeDelta::days(1)),
headers: HeaderMap::new(),
mime: Some(mime::TEXT_HTML),
private: false,
};
})
}
fn render<'a>(
&'a self,
ctx: &'a RenderContext,
) -> Pin<Box<dyn Future<Output = Rendered<RenderedBody>> + 'a + Send + Sync>> {
Box::pin(async { self.head(ctx).await.with_body(RenderedBody::Empty) })
}
}
#[derive(Clone)]
pub struct ServableRouter {
pages: Arc<HashMap<String, Arc<dyn Servable>>>,
notfound: Arc<dyn Servable>,
}
impl ServableRouter {
#[inline(always)]
pub fn new() -> Self {
Self {
pages: Arc::new(HashMap::new()),
notfound: Arc::new(Default404 {}),
}
}
#[inline(always)]
pub fn with_404<S: Servable + 'static>(mut self, page: S) -> Self {
self.notfound = Arc::new(page);
self
}
#[inline(always)]
pub fn add_page<S: Servable + 'static>(mut self, route: impl Into<String>, page: S) -> Self {
let route = route.into();
if !route.starts_with("/") {
panic!("route must start with /")
};
if route.ends_with("/") && route != "/" {
panic!("route must not end with /")
};
if route.contains("//") {
panic!("route must not contain //")
};
#[expect(clippy::expect_used)]
Arc::get_mut(&mut self.pages)
.expect("add_pages called after service was started")
.insert(route, Arc::new(page));
self
}
#[inline(always)]
pub fn add_page_with_route<S: Servable + 'static>(
self,
servable_with_route: &'static ServableWithRoute<S>,
) -> Self {
self.add_page(servable_with_route.route(), servable_with_route)
}
#[inline(always)]
pub fn into_router<T: Clone + Send + Sync + 'static>(self) -> Router<T> {
Router::new().fallback_service(self)
}
}
impl Service<Request<Body>> for ServableRouter {
type Response = Response;
type Error = Infallible;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
if req.method() != Method::GET && req.method() != Method::HEAD {
let mut headers = HeaderMap::with_capacity(1);
headers.insert(header::ACCEPT, HeaderValue::from_static("GET,HEAD"));
return Box::pin(async {
Ok((StatusCode::METHOD_NOT_ALLOWED, headers).into_response())
});
}
let pages = self.pages.clone();
let notfound = self.notfound.clone();
Box::pin(async move {
let addr = req.extensions().get::<SocketAddr>().copied();
let route = req.uri().path().to_owned();
let headers = req.headers().clone();
let query: BTreeMap<String, String> =
serde_urlencoded::from_str(req.uri().query().unwrap_or("")).unwrap_or_default();
let start = Instant::now();
let client_info = ClientInfo::from_headers(&headers);
let ua = headers
.get("user-agent")
.and_then(|x| x.to_str().ok())
.unwrap_or("");
trace!(
message = "Serving route",
route,
addr = ?addr,
user_agent = ua,
device_type = ?client_info.device_type
);
if (route.ends_with('/') && route != "/") || route.contains("//") {
let mut new_route = route.clone();
while new_route.contains("//") {
new_route = new_route.replace("//", "/");
}
let new_route = new_route.trim_matches('/');
trace!(
message = "Redirecting",
route,
new_route,
addr = ?addr,
user_agent = ua,
device_type = ?client_info.device_type
);
let mut headers = HeaderMap::with_capacity(1);
match HeaderValue::from_str(&format!("/{new_route}")) {
Ok(x) => headers.append(header::LOCATION, x),
Err(_) => return Ok(StatusCode::BAD_REQUEST.into_response()),
};
return Ok((StatusCode::PERMANENT_REDIRECT, headers).into_response());
}
let ctx = RenderContext {
client_info,
route,
query,
};
let page = pages.get(&ctx.route).unwrap_or(¬found);
let mut rend = match req.method() == Method::HEAD {
true => page.head(&ctx).await.with_body(RenderedBody::Empty),
false => page.render(&ctx).await,
};
{
if !rend.headers.contains_key(header::CACHE_CONTROL) {
let max_age = rend.ttl.map(|x| x.num_seconds()).unwrap_or(0).max(0);
let mut value = String::new();
value.push_str(match rend.private {
true => "private, ",
false => "public, ",
});
value.push_str(&format!("max-age={}, ", max_age));
#[expect(clippy::unwrap_used)]
rend.headers.insert(
header::CACHE_CONTROL,
HeaderValue::from_str(value.trim().trim_end_matches(',')).unwrap(),
);
}
if !rend.headers.contains_key("Accept-CH") {
rend.headers
.insert("Accept-CH", HeaderValue::from_static("Sec-CH-UA-Mobile"));
}
if !rend.headers.contains_key(header::CONTENT_TYPE)
&& let Some(mime) = &rend.mime
{
#[expect(clippy::unwrap_used)]
rend.headers.insert(
header::CONTENT_TYPE,
HeaderValue::from_str(mime.as_ref()).unwrap(),
);
}
}
trace!(
message = "Served route",
route = ctx.route,
addr = ?addr,
user_agent = ua,
device_type = ?client_info.device_type,
time_ns = start.elapsed().as_nanos()
);
Ok(match rend.body {
RenderedBody::Static(d) => (rend.code, rend.headers, d).into_response(),
RenderedBody::Bytes(d) => (rend.code, rend.headers, d).into_response(),
RenderedBody::String(s) => (rend.code, rend.headers, s).into_response(),
RenderedBody::Empty => (rend.code, rend.headers).into_response(),
})
})
}
}