use std::{borrow::Cow, convert::Infallible, future::Future, pin::Pin, sync::Arc, task::Poll};
use axum_core::body::Body;
use axum_core::extract::Request;
use axum_core::response::Response;
use chrono::{DateTime, Utc};
use http::StatusCode;
use rust_embed::RustEmbed;
use tower_service::Service;
/// Assets embedded from `src/assets`.
///
/// `404.html` from this set is used as the last-resort response body when
/// neither the requested file nor a configured fallback file can be found.
#[derive(Clone, RustEmbed)]
#[folder = "src/assets"]
struct DefaultFallback;
/// How a request is answered when the requested file is not embedded and the
/// fallback file is used instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FallbackBehavior {
/// Serve the fallback content with a `404 Not Found` status.
NotFound,
/// Answer with a temporary redirect pointing at the fallback file.
Redirect,
/// Serve the fallback content with a `200 OK` status.
Ok,
}
/// A [`Service`] that serves the files embedded through the `RustEmbed` type `E`.
///
/// The optional strings are kept behind `Arc` so that every per-request
/// future can share them without copying the text.
#[derive(Debug, Clone)]
pub struct ServeEmbed<E: RustEmbed + Clone> {
/// Ties the service to `E` without storing a value of it.
_phantom: std::marker::PhantomData<E>,
/// File used when the requested path cannot be resolved, if any.
fallback_file: Arc<Option<String>>,
/// How a response built from the fallback file is presented to the client.
fallback_behavior: FallbackBehavior,
/// File name looked up for the root path and for directory-like paths, if any.
index_file: Arc<Option<String>>,
}
impl<E: RustEmbed + Clone> ServeEmbed<E> {
    /// Creates a service with the default configuration: no fallback file,
    /// [`FallbackBehavior::NotFound`], and `index.html` as the index file.
    pub fn new() -> Self {
        Self::with_parameters(
            None,
            FallbackBehavior::NotFound,
            Some("index.html".to_owned()),
        )
    }

    /// Creates a service with an explicit configuration.
    ///
    /// * `fallback_file` - file used when the requested path cannot be
    ///   resolved; `None` means the built-in 404 page is used.
    /// * `fallback_behavior` - how a fallback response is presented
    ///   (see [`FallbackBehavior`]).
    /// * `index_file` - file name looked up for the root path and for
    ///   directory-like paths; `None` disables index lookup.
    pub fn with_parameters(
        fallback_file: Option<String>,
        fallback_behavior: FallbackBehavior,
        index_file: Option<String>,
    ) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
            fallback_file: Arc::new(fallback_file),
            fallback_behavior,
            index_file: Arc::new(index_file),
        }
    }
}

/// `ServeEmbed::default()` is equivalent to [`ServeEmbed::new`].
impl<E: RustEmbed + Clone> Default for ServeEmbed<E> {
    fn default() -> Self {
        Self::new()
    }
}
impl<E: RustEmbed + Clone, T: Send + 'static> Service<http::request::Request<T>> for ServeEmbed<E> {
type Response = Response;
type Error = Infallible;
type Future = ServeFuture<E, T>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: http::request::Request<T>) -> Self::Future {
ServeFuture {
_phantom: std::marker::PhantomData,
fallback_behavior: self.fallback_behavior,
fallback_file: self.fallback_file.clone(),
index_file: self.index_file.clone(),
request: req,
}
}
}
/// Content codings that can be served from pre-compressed sibling files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum CompressionMethod {
    /// No content coding; the file is served as embedded.
    Identity,
    /// `br`, looked up as `<path>.br`.
    Brotli,
    /// `gzip`, looked up as `<path>.gz`.
    Gzip,
    /// `deflate`, looked up as `<path>.zz`.
    Zlib,
}

impl CompressionMethod {
    /// Suffix appended to a file path to find its pre-compressed variant.
    fn extension(self) -> &'static str {
        match self {
            Self::Identity => "",
            Self::Brotli => ".br",
            Self::Gzip => ".gz",
            Self::Zlib => ".zz",
        }
    }
}

/// Parses an `Accept-Encoding` header value into the codings to try, in the
/// order they appear in the header.
///
/// * Coding names are matched case-insensitively, as HTTP requires.
/// * A coding carrying a zero quality value (`q=0`) is explicitly refused by
///   the client and is therefore skipped.
/// * Unknown codings (including `*`) are ignored.
/// * [`CompressionMethod::Identity`] is always present in the result: it is
///   appended as a last resort when the header did not list it as acceptable,
///   so that a file can always be served.
fn from_acceptable_encoding(acceptable_encoding: Option<&str>) -> Vec<CompressionMethod> {
    /// Returns `true` when `param` is a `q` parameter whose value is zero.
    /// Malformed values are treated leniently, i.e. as "not refused".
    fn is_zero_qvalue(param: &str) -> bool {
        match param.split_once('=') {
            Some((name, value)) => {
                name.trim().eq_ignore_ascii_case("q")
                    && matches!(value.trim().parse::<f32>(), Ok(q) if q <= 0.0)
            }
            None => false,
        }
    }

    let mut compression_methods = Vec::new();
    for entry in acceptable_encoding.unwrap_or("").split(',') {
        let mut parts = entry.split(';');
        // Trim after splitting so that `gzip ;q=1` is still recognised.
        let coding = parts.next().unwrap_or("").trim();
        if parts.any(is_zero_qvalue) {
            continue;
        }
        let method = if coding.eq_ignore_ascii_case("br") {
            CompressionMethod::Brotli
        } else if coding.eq_ignore_ascii_case("gzip") {
            CompressionMethod::Gzip
        } else if coding.eq_ignore_ascii_case("deflate") {
            CompressionMethod::Zlib
        } else if coding.eq_ignore_ascii_case("identity") {
            CompressionMethod::Identity
        } else {
            continue;
        };
        compression_methods.push(method);
    }
    if !compression_methods.contains(&CompressionMethod::Identity) {
        compression_methods.push(CompressionMethod::Identity);
    }
    compression_methods
}
/// Outcome of resolving a request path against the embedded files.
struct GetFileResult<'a> {
/// Path the lookup settled on; later used to guess the response content type.
path: Cow<'a, str>,
/// The embedded file to serve, when one was found.
file: Option<rust_embed::EmbeddedFile>,
/// When set, the client must be redirected to this location instead.
should_redirect: Option<String>,
/// Content coding of `file` (identity unless a pre-compressed variant was picked).
compression_method: CompressionMethod,
/// Whether this result comes from the fallback path rather than the requested file.
is_fallback: bool,
}
/// Future returned by `ServeEmbed::call`; it holds the request together with a
/// snapshot of the service configuration and produces the response when polled.
#[derive(Debug, Clone)]
pub struct ServeFuture<E: RustEmbed, T> {
/// Ties the future to the embed type `E` without storing a value of it.
_phantom: std::marker::PhantomData<E>,
/// How a response built from the fallback file is presented to the client.
fallback_behavior: FallbackBehavior,
/// File used when the requested path cannot be resolved, if any.
fallback_file: Arc<Option<String>>,
/// File name looked up for the root path and for directory-like paths, if any.
index_file: Arc<Option<String>>,
/// The request being answered.
request: Request<T>,
}
impl<E: RustEmbed, T> ServeFuture<E, T> {
    /// Resolves `path` against the files embedded in `E`.
    ///
    /// Leading slashes are ignored. The configured index file is used for the
    /// root path and for paths ending in `/`; a path without trailing slash
    /// whose `<path>/<index>` exists yields a redirect to `/<path>/`.
    ///
    /// When the file exists, `acceptable_encoding` is walked in order and the
    /// first coding for which a pre-compressed sibling (`<path><extension>`)
    /// is embedded wins; reaching `Identity` keeps the plain file.
    fn get_file<'a>(
        &self,
        path: &'a str,
        acceptable_encoding: &[CompressionMethod],
    ) -> GetFileResult<'a> {
        let index_file = self.index_file.as_ref().as_ref();
        let mut path_candidate = Cow::Borrowed(path.trim_start_matches('/'));

        if path_candidate.is_empty() {
            if let Some(index_file) = index_file {
                path_candidate = Cow::Owned(index_file.to_string());
            }
        } else if path_candidate.ends_with('/') {
            if let Some(index_file) = index_file {
                let with_index = format!("{}{}", path_candidate, index_file);
                if E::get(&with_index).is_some() {
                    path_candidate = Cow::Owned(with_index);
                }
            }
        } else if let Some(index_file) = index_file {
            let with_index = format!("{}/{}", path_candidate, index_file);
            if E::get(&with_index).is_some() {
                // Directory requested without trailing slash: send the client
                // to the canonical `/dir/` URL.
                return GetFileResult {
                    path: Cow::Owned(with_index),
                    file: None,
                    should_redirect: Some(format!("/{}/", path_candidate)),
                    compression_method: CompressionMethod::Identity,
                    is_fallback: false,
                };
            }
        }

        let mut file = E::get(&path_candidate);
        let mut compressed_method = CompressionMethod::Identity;
        if file.is_some() {
            for &method in acceptable_encoding {
                if method == CompressionMethod::Identity {
                    // The plain file is already loaded; no second lookup needed.
                    break;
                }
                let compressed_path = format!("{}{}", path_candidate, method.extension());
                if let Some(compressed) = E::get(&compressed_path) {
                    file = Some(compressed);
                    compressed_method = method;
                    break;
                }
            }
        }

        GetFileResult {
            path: path_candidate,
            file,
            should_redirect: None,
            compression_method: compressed_method,
            is_fallback: false,
        }
    }

    /// Resolves `path` like [`Self::get_file`], falling back to the configured
    /// fallback file and finally to the built-in `404.html` page.
    ///
    /// With [`FallbackBehavior::Redirect`] a miss is answered by a redirect to
    /// the fallback file, unless the request already targets that file. In
    /// that case the fallback is resolved directly so that a missing fallback
    /// file cannot cause an endless redirect loop.
    fn get_file_with_fallback<'a, 'b: 'a>(
        &'b self,
        path: &'a str,
        acceptable_encoding: &[CompressionMethod],
    ) -> GetFileResult<'a> {
        let first_try = self.get_file(path, acceptable_encoding);
        if first_try.file.is_some() || first_try.should_redirect.is_some() {
            return first_try;
        }

        if let Some(fallback_file) = self.fallback_file.as_ref().as_ref() {
            // Request paths start with `/` while the configured name usually
            // does not: compare both without leading slashes, otherwise the
            // loop guard below never triggers.
            let fallback_target = fallback_file.trim_start_matches('/');
            let requests_fallback = fallback_target == path.trim_start_matches('/');
            if self.fallback_behavior == FallbackBehavior::Redirect && !requests_fallback {
                return GetFileResult {
                    path: Cow::Borrowed(path),
                    file: None,
                    // Exactly one leading slash: `//name` would be read by
                    // clients as a protocol-relative URL to another host.
                    should_redirect: Some(format!("/{}", fallback_target)),
                    compression_method: CompressionMethod::Identity,
                    is_fallback: true,
                };
            }
            let mut fallback_try = self.get_file(fallback_file, acceptable_encoding);
            fallback_try.is_fallback = true;
            if fallback_try.file.is_some() {
                return fallback_try;
            }
        }

        GetFileResult {
            path: Cow::Borrowed("404.html"),
            file: DefaultFallback::get("404.html"),
            should_redirect: None,
            compression_method: CompressionMethod::Identity,
            is_fallback: true,
        }
    }
}
/// Produces the response on the first poll; the future never returns `Pending`.
impl<E: RustEmbed, T> Future for ServeFuture<E, T> {
    type Output = Result<Response<Body>, Infallible>;

    /// Answers the stored request:
    ///
    /// * methods other than `GET`/`HEAD` get `405` with an `Allow` header;
    /// * redirects are `301` for directory canonicalisation and `307` for
    ///   fallback redirects;
    /// * a matching `If-None-Match` on a non-fallback file yields `304`;
    /// * otherwise the file is served with content type, quoted `ETag`,
    ///   `Vary: Accept-Encoding`, optional `Content-Encoding` and
    ///   `Last-Modified`, with status `404` for fallback content unless the
    ///   behavior is [`FallbackBehavior::Ok`].
    fn poll(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let method = self.request.method();
        if method != http::Method::GET && method != http::Method::HEAD {
            return Poll::Ready(Ok(Response::builder()
                .status(StatusCode::METHOD_NOT_ALLOWED)
                // A 405 response must tell the client which methods are supported.
                .header(http::header::ALLOW, "GET, HEAD")
                .header(http::header::CONTENT_TYPE, "text/plain")
                .body(Body::from("Method not allowed"))
                .unwrap()));
        }

        let accepted_encodings = from_acceptable_encoding(
            self.request
                .headers()
                .get(http::header::ACCEPT_ENCODING)
                .and_then(|value| value.to_str().ok()),
        );
        let lookup = self.get_file_with_fallback(self.request.uri().path(), &accepted_encodings);

        let (path, file, compression_method, is_fallback) = match lookup {
            GetFileResult {
                should_redirect: Some(location),
                is_fallback,
                ..
            } => {
                let (status, message) = if is_fallback {
                    (StatusCode::TEMPORARY_REDIRECT, "Temporary redirect")
                } else {
                    (StatusCode::MOVED_PERMANENTLY, "Moved permanently")
                };
                return Poll::Ready(Ok(Response::builder()
                    .status(status)
                    .header(http::header::LOCATION, location)
                    .header(http::header::CONTENT_TYPE, "text/plain")
                    .body(Body::from(message))
                    .unwrap()));
            }
            GetFileResult {
                path,
                file: Some(file),
                compression_method,
                is_fallback,
                ..
            } => (path, file, compression_method, is_fallback),
            GetFileResult { file: None, .. } => {
                // Only reachable when even the built-in 404 page is missing
                // from the embedded assets: answer plainly instead of panicking.
                return Poll::Ready(Ok(Response::builder()
                    .status(StatusCode::NOT_FOUND)
                    .header(http::header::CONTENT_TYPE, "text/plain")
                    .body(Body::from("Not found"))
                    .unwrap()));
            }
        };

        let etag_hash = hash_to_string(&file.metadata.sha256_hash());
        // Entity tags must be quoted strings on the wire.
        let etag = format!("\"{}\"", etag_hash);

        let not_modified = !is_fallback
            && self
                .request
                .headers()
                .get(http::header::IF_NONE_MATCH)
                .and_then(|value| value.to_str().ok())
                .map_or(false, |value| if_none_match_matches(value, &etag_hash));
        if not_modified {
            return Poll::Ready(Ok(Response::builder()
                .status(StatusCode::NOT_MODIFIED)
                .header(http::header::ETAG, etag)
                .header(http::header::VARY, "Accept-Encoding")
                .body(Body::empty())
                .unwrap()));
        }

        let mut response_builder = Response::builder()
            .header(
                http::header::CONTENT_TYPE,
                mime_guess::from_path(path.as_ref())
                    .first_or_octet_stream()
                    .to_string(),
            )
            .header(http::header::ETAG, etag)
            // The body depends on the request's Accept-Encoding; caches must
            // not hand a compressed variant to a client that did not ask for it.
            .header(http::header::VARY, "Accept-Encoding");

        let content_encoding = match compression_method {
            CompressionMethod::Identity => None,
            CompressionMethod::Brotli => Some("br"),
            CompressionMethod::Gzip => Some("gzip"),
            CompressionMethod::Zlib => Some("deflate"),
        };
        if let Some(content_encoding) = content_encoding {
            response_builder =
                response_builder.header(http::header::CONTENT_ENCODING, content_encoding);
        }

        if let Some(last_modified) = file.metadata.last_modified() {
            response_builder =
                response_builder.header(http::header::LAST_MODIFIED, date_to_string(last_modified));
        }

        let status = if is_fallback && self.fallback_behavior != FallbackBehavior::Ok {
            StatusCode::NOT_FOUND
        } else {
            StatusCode::OK
        };

        Poll::Ready(Ok(response_builder
            .status(status)
            .body(Body::from(file.data))
            .unwrap()))
    }
}

/// Tells whether an `If-None-Match` header value matches the file whose hex
/// hash is `etag_hash`.
///
/// The header may hold a comma-separated list; each entry may be quoted or
/// not, may carry the weak prefix `W/`, or may be `*` (matches anything).
fn if_none_match_matches(header_value: &str, etag_hash: &str) -> bool {
    header_value.split(',').map(str::trim).any(|candidate| {
        candidate == "*"
            || candidate
                .strip_prefix("W/")
                .unwrap_or(candidate)
                .trim_matches('"')
                == etag_hash
    })
}
/// Renders a 32-byte hash as 64 lowercase hexadecimal characters.
///
/// Digits are taken from a lookup table so that no temporary string is
/// allocated per byte.
fn hash_to_string(hash: &[u8; 32]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut s = String::with_capacity(hash.len() * 2);
    for &byte in hash {
        // High nibble first, then low nibble.
        s.push(HEX_DIGITS[usize::from(byte >> 4)] as char);
        s.push(HEX_DIGITS[usize::from(byte & 0x0f)] as char);
    }
    s
}
/// Formats a timestamp given in seconds as an HTTP date such as
/// `Thu, 01 Jan 1970 00:00:00 GMT`.
///
/// # Panics
///
/// Panics when chrono cannot build a date from `date`.
fn date_to_string(date: u64) -> String {
    let seconds = date as i64;
    let moment = DateTime::<Utc>::from_timestamp(seconds, 0).unwrap();
    let rendered = moment.format("%a, %d %b %Y %H:%M:%S GMT");
    rendered.to_string()
}
// The tests live in the separate `test` module and are compiled only for test builds.
#[cfg(test)]
mod test;