use crate::{
content_encoding::{encodings, SupportedEncodings},
set_status::SetStatus,
};
use async_lock::Mutex;
use bytes::Bytes;
use http::{header, HeaderValue, Method, Request, Response, StatusCode};
use http_body_util::{combinators::UnsyncBoxBody, BodyExt, Empty};
use percent_encoding::percent_decode;
use std::{
convert::Infallible,
io,
path::{Component, Path, PathBuf},
sync::Arc,
};
use tower_async_service::Service;
pub(crate) mod future;
mod headers;
mod open_file;
#[cfg(test)]
mod tests;
/// Default value for `ServeDir::buf_chunk_size` (64 KiB, i.e. 65536).
const DEFAULT_CAPACITY: usize = 64 * 1024;
/// Service that serves files beneath a base directory, or one fixed file.
///
/// Create it with [`ServeDir::new`] and configure it through the builder methods.
/// Cloning a `ServeDir` clones the `Arc` around the fallback slot, so all clones share it.
#[derive(Clone, Debug)]
pub struct ServeDir<F = DefaultServeDirFallback> {
/// Root directory for `ServeVariant::Directory` (`"."` joined with the user path), or the
/// file itself for `ServeVariant::SingleFile`.
base: PathBuf,
/// Handed to the file opener as its `buf_chunk_size`; defaults to `DEFAULT_CAPACITY`.
buf_chunk_size: usize,
/// Precompressed encodings offered to `encodings()` during negotiation; `None` until one of
/// the `precompressed_*` builders is called (treated as "all disabled").
precompressed_variants: Option<PrecompressedVariants>,
/// Whether a directory tree or a single file is served.
variant: ServeVariant,
/// Optional fallback service. It is used when the request path is rejected and is handed to
/// `consume_open_file_result` after the file-open attempt.
fallback: Arc<Mutex<Option<F>>>,
/// When `true`, non-GET/HEAD requests go to the fallback instead of receiving 405.
call_fallback_on_method_not_allowed: bool,
}
impl ServeDir<DefaultServeDirFallback> {
pub fn new<P>(path: P) -> Self
where
P: AsRef<Path>,
{
let mut base = PathBuf::from(".");
base.push(path.as_ref());
Self {
base,
buf_chunk_size: DEFAULT_CAPACITY,
precompressed_variants: None,
variant: ServeVariant::Directory {
append_index_html_on_directories: true,
},
fallback: Arc::new(Mutex::new(None)),
call_fallback_on_method_not_allowed: false,
}
}
pub(crate) fn new_single_file<P>(path: P, mime: HeaderValue) -> Self
where
P: AsRef<Path>,
{
Self {
base: path.as_ref().to_owned(),
buf_chunk_size: DEFAULT_CAPACITY,
precompressed_variants: None,
variant: ServeVariant::SingleFile { mime },
fallback: Arc::new(Mutex::new(None)),
call_fallback_on_method_not_allowed: false,
}
}
}
impl<F> ServeDir<F> {
    /// Sets the `append_index_html_on_directories` flag of a directory service.
    ///
    /// The flag is stored on the directory variant and passed to the file opener (not shown
    /// here); presumably it makes a request for a directory serve its `index.html`.
    /// A single-file service is returned unchanged.
    pub fn append_index_html_on_directories(mut self, append: bool) -> Self {
        match &mut self.variant {
            ServeVariant::Directory {
                append_index_html_on_directories,
            } => {
                *append_index_html_on_directories = append;
                self
            }
            ServeVariant::SingleFile { mime: _ } => self,
        }
    }

    /// Sets the chunk size handed to the file opener as `buf_chunk_size`
    /// (default: `DEFAULT_CAPACITY`).
    pub fn with_buf_chunk_size(mut self, chunk_size: usize) -> Self {
        self.buf_chunk_size = chunk_size;
        self
    }

    /// Enables `gzip` in the set of precompressed encodings offered to negotiation.
    pub fn precompressed_gzip(mut self) -> Self {
        self.precompressed_variants
            .get_or_insert(Default::default())
            .gzip = true;
        self
    }

    /// Enables `br` (Brotli) in the set of precompressed encodings offered to negotiation.
    pub fn precompressed_br(mut self) -> Self {
        self.precompressed_variants
            .get_or_insert(Default::default())
            .br = true;
        self
    }

    /// Enables `deflate` in the set of precompressed encodings offered to negotiation.
    pub fn precompressed_deflate(mut self) -> Self {
        self.precompressed_variants
            .get_or_insert(Default::default())
            .deflate = true;
        self
    }

    /// Enables `zstd` in the set of precompressed encodings offered to negotiation.
    pub fn precompressed_zstd(mut self) -> Self {
        self.precompressed_variants
            .get_or_insert(Default::default())
            .zstd = true;
        self
    }

    /// Installs `new_fallback`, replacing any previous fallback (and changing the type).
    ///
    /// The fallback is called when the request path is rejected. It is also handed to
    /// `consume_open_file_result` after the file-open attempt and, if enabled, called for
    /// disallowed methods. All other settings are carried over.
    pub fn fallback<F2>(self, new_fallback: F2) -> ServeDir<F2> {
        ServeDir {
            base: self.base,
            buf_chunk_size: self.buf_chunk_size,
            precompressed_variants: self.precompressed_variants,
            variant: self.variant,
            fallback: Arc::new(Mutex::new(Some(new_fallback))),
            call_fallback_on_method_not_allowed: self.call_fallback_on_method_not_allowed,
        }
    }

    /// Like [`ServeDir::fallback`], but wraps `new_fallback` in `SetStatus` so its responses
    /// carry `404 Not Found`.
    pub fn not_found_service<F2>(self, new_fallback: F2) -> ServeDir<SetStatus<F2>> {
        self.fallback(SetStatus::new(new_fallback, StatusCode::NOT_FOUND))
    }

    /// When `true`, requests whose method is neither GET nor HEAD are forwarded to the
    /// fallback (if one is installed) instead of receiving `405 Method Not Allowed`.
    pub fn call_fallback_on_method_not_allowed(mut self, call_fallback: bool) -> Self {
        self.call_fallback_on_method_not_allowed = call_fallback;
        self
    }

    /// Serves `req`, returning I/O failures to the caller instead of mapping them to a 500.
    ///
    /// Non-GET/HEAD requests get `405 Method Not Allowed` unless
    /// `call_fallback_on_method_not_allowed` is enabled *and* a fallback is installed. In that
    /// case the fallback handles them. If the request path fails validation, the fallback
    /// (or a plain not-found response) is used. Otherwise the file is opened and the result
    /// is resolved by `consume_open_file_result`.
    ///
    /// # Errors
    ///
    /// Propagates any `io::Error` returned by the file-opening or fallback helpers.
    pub async fn try_call<ReqBody, FResBody>(
        &self,
        req: Request<ReqBody>,
    ) -> Result<Response<ResponseBody>, std::io::Error>
    where
        F: Service<Request<ReqBody>, Response = Response<FResBody>, Error = Infallible> + Clone,
        FResBody: http_body::Body<Data = Bytes> + Send + 'static,
        FResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        if req.method() != Method::GET && req.method() != Method::HEAD {
            if !self.call_fallback_on_method_not_allowed {
                return Ok(future::method_not_allowed());
            }
            // Clone the fallback out so the lock guard is dropped at the end of this
            // statement. Holding it across the fallback's `.await` would block every other
            // request that shares this `ServeDir`.
            let fallback = self.fallback.lock().await.as_ref().cloned();
            return match fallback {
                Some(fallback) => future::call_fallback(&fallback, req).await,
                // Opting into fallback routing without a fallback must not make
                // POST/PUT/DELETE/... serve the file: answer 405 as usual.
                None => Ok(future::method_not_allowed()),
            };
        }

        let (mut parts, body) = req.into_parts();

        // Same as above: take a clone and release the lock immediately. Without
        // `poll_ready`, a fresh clone is as good as the stored instance.
        let fallback = self.fallback.lock().await.as_ref().cloned();

        // `ServeDir` ignores the request body, but the fallback may need it. Bodies cannot
        // be cloned, so the body (and the extensions) move to the fallback's request and the
        // file-serving request gets an empty body instead.
        let mut fallback_and_request = fallback.map(|fallback| {
            let mut fallback_req = Request::new(body);
            *fallback_req.method_mut() = parts.method.clone();
            *fallback_req.uri_mut() = parts.uri.clone();
            *fallback_req.version_mut() = parts.version;
            *fallback_req.headers_mut() = parts.headers.clone();
            *fallback_req.extensions_mut() = std::mem::take(&mut parts.extensions);
            (fallback, fallback_req)
        });
        let req = Request::from_parts(parts, Empty::<Bytes>::new());

        let path_to_file = match self
            .variant
            .build_and_validate_path(&self.base, req.uri().path())
        {
            Some(path_to_file) => path_to_file,
            None => {
                return if let Some((fallback, request)) = fallback_and_request.take() {
                    future::call_fallback(&fallback, request).await
                } else {
                    Ok(future::not_found())
                };
            }
        };

        let buf_chunk_size = self.buf_chunk_size;
        let range_header = req
            .headers()
            .get(header::RANGE)
            .and_then(|value| value.to_str().ok())
            .map(|s| s.to_owned());
        let negotiated_encodings = encodings(
            req.headers(),
            self.precompressed_variants.unwrap_or_default(),
        );
        let variant = self.variant.clone();

        let open_file_result = open_file::open_file(
            variant,
            path_to_file,
            req,
            negotiated_encodings,
            range_header,
            buf_chunk_size,
        )
        .await;

        future::consume_open_file_result(open_file_result, fallback_and_request).await
    }
}
impl<ReqBody, F, FResBody> Service<Request<ReqBody>> for ServeDir<F>
where
    F: Service<Request<ReqBody>, Response = Response<FResBody>, Error = Infallible> + Clone,
    FResBody: http_body::Body<Data = Bytes> + Send + 'static,
    FResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Response = Response<ResponseBody>;
    type Error = Infallible;

    /// Infallible wrapper around [`ServeDir::try_call`].
    ///
    /// An `io::Error` is logged through `tracing::error!` and turned into an empty-bodied
    /// `500 Internal Server Error` response.
    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
        match self.try_call(req).await {
            Ok(served) => Ok(served),
            Err(err) => {
                tracing::error!(error = %err, "Failed to read file");
                let mut failure = Response::new(ResponseBody::new(
                    Empty::new().map_err(|never| match never {}).boxed_unsync(),
                ));
                *failure.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                Ok(failure)
            }
        }
    }
}
/// What a [`ServeDir`] serves.
#[derive(Clone, Debug)]
enum ServeVariant {
/// Files resolved beneath the base directory from the request path.
Directory {
/// Set via `ServeDir::append_index_html_on_directories`; presumably makes directory
/// requests serve `index.html` (confirm in `open_file`).
append_index_html_on_directories: bool,
},
/// One fixed file (the base path), served regardless of the request path.
SingleFile {
/// Stored at construction; presumably used as the response `Content-Type` (confirm in
/// `open_file`).
mime: HeaderValue,
},
}
impl ServeVariant {
    /// Maps a request URI path to a filesystem path, or `None` if it must be rejected.
    ///
    /// A single-file variant always yields `base_path` unchanged. For a directory variant,
    /// leading slashes are stripped, the rest is percent-decoded (non-UTF-8 is rejected), and
    /// each component is appended to `base_path`. `.` is skipped. Any root, prefix, or `..`
    /// component, or a "normal" component that itself parses into non-normal parts,
    /// rejects the whole path so requests cannot escape the base directory.
    fn build_and_validate_path(&self, base_path: &Path, requested_path: &str) -> Option<PathBuf> {
        let ServeVariant::Directory { .. } = self else {
            return Some(base_path.to_path_buf());
        };

        let relative = requested_path.trim_start_matches('/');
        let decoded = percent_decode(relative.as_bytes()).decode_utf8().ok()?;

        let mut resolved = base_path.to_path_buf();
        for part in Path::new(&*decoded).components() {
            match part {
                Component::CurDir => continue,
                Component::Prefix(_) | Component::RootDir | Component::ParentDir => return None,
                Component::Normal(segment) => {
                    // Re-parse the segment on its own; on some platforms it can still hide
                    // a prefix such as a drive letter.
                    let is_plain = Path::new(segment)
                        .components()
                        .all(|inner| matches!(inner, Component::Normal(_)));
                    if !is_plain {
                        return None;
                    }
                    resolved.push(segment);
                }
            }
        }
        Some(resolved)
    }
}
// Body type of every response produced by `ServeDir` and `DefaultServeDirFallback`: an
// `UnsyncBoxBody` yielding `Bytes` chunks with `io::Error` errors. `opaque_body!` is defined
// elsewhere in the crate; presumably it wraps this alias in a newtype to keep the concrete
// type out of the public API (confirm against the macro definition).
opaque_body! {
#[derive(Default)]
pub type ResponseBody = UnsyncBoxBody<Bytes, io::Error>;
}
/// Fallback type of a [`ServeDir`] that has no fallback configured.
///
/// It wraps [`Infallible`], so no value of this type can exist; it only serves as the
/// default type parameter.
#[derive(Debug, Clone, Copy)]
pub struct DefaultServeDirFallback(Infallible);
impl<ReqBody> Service<Request<ReqBody>> for DefaultServeDirFallback
where
ReqBody: Send + 'static,
{
type Response = Response<ResponseBody>;
type Error = Infallible;
async fn call(&self, _req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
match self.0 {}
}
}
/// Flags for which precompressed encodings were enabled through the `precompressed_*`
/// builders on `ServeDir`. Every flag defaults to `false`.
#[derive(Clone, Copy, Debug, Default)]
struct PrecompressedVariants {
gzip: bool,
deflate: bool,
br: bool,
zstd: bool,
}
/// Exposes each enabled flag to content-encoding negotiation.
impl SupportedEncodings for PrecompressedVariants {
    fn gzip(&self) -> bool {
        let Self { gzip, .. } = *self;
        gzip
    }

    fn deflate(&self) -> bool {
        let Self { deflate, .. } = *self;
        deflate
    }

    fn br(&self) -> bool {
        let Self { br, .. } = *self;
        br
    }

    fn zstd(&self) -> bool {
        let Self { zstd, .. } = *self;
        zstd
    }
}