use super::AsyncReadBody;
use bytes::Bytes;
use futures_util::ready;
use http::{header, HeaderValue, Request, Response, StatusCode, Uri};
use http_body::{combinators::BoxBody, Body, Empty};
use std::{
future::Future,
io,
path::{Path, PathBuf},
pin::Pin,
task::{Context, Poll},
};
use tokio::fs::File;
use tower_service::Service;
/// Service that serves files from a directory on the local filesystem.
///
/// Request paths are resolved relative to the base directory given to
/// [`ServeDir::new`].
#[derive(Clone, Debug)]
pub struct ServeDir {
    // Directory that request paths are resolved against.
    base: PathBuf,
    // When `true`, a request for a directory path ending in `/` serves that
    // directory's `index.html`.
    append_index_html_on_directories: bool,
}

impl ServeDir {
    /// Creates a service that serves files from `path`.
    ///
    /// The base directory is `path` joined onto `.`. A relative `path` therefore
    /// becomes `./path`. An absolute `path` replaces the `.` entirely (standard
    /// `Path::join` semantics).
    ///
    /// Serving `index.html` for directory requests is enabled by default.
    pub fn new<P: AsRef<Path>>(path: P) -> Self {
        Self {
            base: Path::new(".").join(path),
            append_index_html_on_directories: true,
        }
    }

    /// Sets whether a request for a directory path ending in `/` serves the
    /// `index.html` file inside that directory.
    ///
    /// Defaults to `true`.
    pub fn append_index_html_on_directories(self, append: bool) -> Self {
        Self {
            append_index_html_on_directories: append,
            ..self
        }
    }
}
impl<ReqBody> Service<Request<ReqBody>> for ServeDir {
    type Response = Response<ResponseBody>;
    type Error = io::Error;
    type Future = ResponseFuture;

    /// Always ready: the service holds no resources that must become available first.
    #[inline]
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Resolves the request path under the base directory and starts opening the file.
    ///
    /// - A path that fails validation (see `build_and_validate_path`) yields a
    ///   `404 Not Found` without touching the filesystem.
    /// - A directory requested without a trailing `/` is redirected to the same
    ///   URI with `/` appended.
    /// - A directory requested with a trailing `/` serves its `index.html` when
    ///   `append_index_html_on_directories` is enabled. Otherwise it is reported
    ///   as not found.
    /// - Anything else is opened as a file, with `Content-Type` guessed from its
    ///   extension.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let mut full_path = match build_and_validate_path(&self.base, req.uri().path()) {
            Some(path) => path,
            None => {
                return ResponseFuture {
                    inner: Inner::Invalid,
                }
            }
        };

        let append_index_html_on_directories = self.append_index_html_on_directories;
        let uri = req.uri().clone();

        let open_file_future = Box::pin(async move {
            if is_dir(&full_path).await {
                if !uri.path().ends_with('/') {
                    let location = HeaderValue::from_str(&append_slash_on_path(uri).to_string())
                        .expect("URI characters are always valid header value characters");
                    return Ok(Output::Redirect(location));
                }
                if append_index_html_on_directories {
                    full_path.push("index.html");
                } else {
                    // Opening a directory as a file can succeed (e.g. on Unix) and
                    // only fail once the body is read, after a 200 has been sent.
                    // Report it as missing up front instead.
                    return Err(io::Error::from(io::ErrorKind::NotFound));
                }
            }

            let mime = HeaderValue::from_static(
                mime_guess::from_path(&full_path)
                    .first_raw()
                    .unwrap_or("application/octet-stream"),
            );
            let file = File::open(full_path).await?;
            Ok(Output::File(file, mime))
        });

        ResponseFuture {
            inner: Inner::Valid(open_file_future),
        }
    }
}

/// Appends each `/`-separated segment of `request_path` to `base`.
///
/// Returns `None` if any segment could escape `base`. Rejected segments are those
/// that start with `..`, contain `\`, or parse as anything other than a plain
/// file name or `.`. The last check catches a Windows drive prefix such as `C:`,
/// which would make `PathBuf::push` discard `base` entirely.
///
/// NOTE(review): `request_path` is used still percent-encoded, so names such as
/// `%20` are looked up literally. Adding decoding would require validating the
/// decoded segments instead.
fn build_and_validate_path(base: &Path, request_path: &str) -> Option<PathBuf> {
    use std::path::Component;

    let mut full_path = base.to_path_buf();
    for seg in request_path.trim_start_matches('/').split('/') {
        if seg.starts_with("..") || seg.contains('\\') {
            return None;
        }
        let is_plain = Path::new(seg)
            .components()
            .all(|c| matches!(c, Component::Normal(_) | Component::CurDir));
        if !is_plain {
            return None;
        }
        full_path.push(seg);
    }
    Some(full_path)
}
/// Returns `true` if the metadata for `full_path` can be read and describes a directory.
///
/// Any error while reading the metadata (for example, a missing path) is
/// treated as "not a directory".
async fn is_dir(full_path: &Path) -> bool {
    match tokio::fs::metadata(full_path).await {
        Ok(metadata) => metadata.is_dir(),
        Err(_) => false,
    }
}
/// Returns `uri` with a `/` appended to its path, keeping its scheme, authority and query.
///
/// A URI without a path-and-query component becomes `/`.
fn append_slash_on_path(uri: Uri) -> Uri {
    let parts = uri.into_parts();

    let new_path_and_query = match parts.path_and_query {
        None => String::from("/"),
        Some(pq) => match pq.query() {
            None => format!("{}/", pq.path()),
            Some(query) => format!("{}/?{}", pq.path(), query),
        },
    };

    let mut builder = Uri::builder();
    if let Some(scheme) = parts.scheme {
        builder = builder.scheme(scheme);
    }
    if let Some(authority) = parts.authority {
        builder = builder.authority(authority);
    }
    builder.path_and_query(new_path_and_query).build().unwrap()
}
/// Heap-allocated, type-erased future that is `Send + Sync + 'static`.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>>;

/// What the file-opening future decided to serve.
enum Output {
    /// A file opened for reading, together with its `Content-Type` header value.
    File(File, HeaderValue),
    /// A directory requested without a trailing `/`; holds the `Location` to redirect to.
    Redirect(HeaderValue),
}

/// State of a [`ResponseFuture`].
enum Inner {
    /// The request path passed validation; this future opens the file or decides to redirect.
    Valid(BoxFuture<io::Result<Output>>),
    /// The request path was rejected; the response is `404 Not Found`.
    Invalid,
}

/// Response future returned by [`ServeDir`].
pub struct ResponseFuture {
    inner: Inner,
}
impl Future for ResponseFuture {
    type Output = io::Result<Response<ResponseBody>>;

    /// Drives the file-opening future and turns its result into a response.
    ///
    /// - An opened file produces a `200 OK` response that streams the file,
    ///   with its `Content-Type` set.
    /// - A redirect produces a `308 Permanent Redirect` with a `Location` header.
    /// - An I/O error is delegated to `super::response_from_io_error`.
    /// - A rejected path answers `404 Not Found` immediately.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let open_result = match &mut self.inner {
            Inner::Invalid => {
                let not_found = Response::builder()
                    .status(StatusCode::NOT_FOUND)
                    .body(empty_body())
                    .unwrap();
                return Poll::Ready(Ok(not_found));
            }
            Inner::Valid(open_file_future) => ready!(open_file_future.as_mut().poll(cx)),
        };

        let response = match open_result {
            Ok(Output::File(file, mime)) => {
                let mut ok = Response::new(ResponseBody(AsyncReadBody::new(file).boxed()));
                ok.headers_mut().insert(header::CONTENT_TYPE, mime);
                ok
            }
            Ok(Output::Redirect(location)) => Response::builder()
                .status(StatusCode::PERMANENT_REDIRECT)
                .header(header::LOCATION, location)
                .body(empty_body())
                .unwrap(),
            Err(err) => {
                return Poll::Ready(
                    super::response_from_io_error(err).map(|res| res.map(ResponseBody)),
                )
            }
        };

        Poll::Ready(Ok(response))
    }
}
/// Returns a response body that yields no data.
fn empty_body() -> ResponseBody {
    // The empty `match` compiles only because `Empty`'s error type is
    // uninhabited, so converting it to `io::Error` has no runtime path.
    ResponseBody(Empty::new().map_err(|never| match never {}).boxed())
}
// Body type of every response produced by `ServeDir`: a boxed body of `Bytes`
// chunks whose errors are `io::Error`. The `opaque_body!` macro is defined
// elsewhere in this crate. Judging by how this file uses it (`ResponseBody(..)`),
// it presumably generates a newtype wrapper around the boxed body — confirm
// against the macro definition.
opaque_body! {
pub type ResponseBody = BoxBody<Bytes, io::Error>;
}
#[cfg(test)]
mod tests {
    // Glob import of the parent module; allowed to be partially unused because
    // the tests reference only a few of its items.
    #[allow(unused_imports)]
    use super::*;
    use http::{Request, StatusCode};
    use http_body::Body as HttpBody;
    use hyper::Body;
    use tower::ServiceExt;

    /// Builds a request for `uri` with an empty body.
    fn get(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn basic() {
        let response = ServeDir::new("..").oneshot(get("/README.md")).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.headers()["content-type"], "text/markdown");

        let text = body_into_text(response.into_body()).await;
        assert_eq!(text, std::fs::read_to_string("../README.md").unwrap());
    }

    #[tokio::test]
    async fn access_to_sub_dirs() {
        let response = ServeDir::new("..")
            .oneshot(get("/tower-http/Cargo.toml"))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.headers()["content-type"], "text/x-toml");

        let text = body_into_text(response.into_body()).await;
        assert_eq!(text, std::fs::read_to_string("Cargo.toml").unwrap());
    }

    #[tokio::test]
    async fn not_found() {
        let response = ServeDir::new("..").oneshot(get("/not-found")).await.unwrap();

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        assert!(response.headers().get(header::CONTENT_TYPE).is_none());

        let text = body_into_text(response.into_body()).await;
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn redirect_to_trailing_slash_on_dir() {
        let response = ServeDir::new(".").oneshot(get("/src")).await.unwrap();

        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(&response.headers()[http::header::LOCATION], "/src/");
    }

    /// Collects `body` into a `String`, panicking on body errors or invalid UTF-8.
    async fn body_into_text<B>(body: B) -> String
    where
        B: HttpBody<Data = bytes::Bytes> + Unpin,
        B::Error: std::fmt::Debug,
    {
        String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap()
    }
}