use std::{
fs,
marker::PhantomData,
path::{Path, PathBuf},
};
use http::{header::HeaderValue, status::StatusCode};
use motore::service::Service;
use super::FileResponse;
use crate::{context::ServerContext, request::Request, response::Response, server::IntoResponse};
/// A service that serves files from a directory on disk.
///
/// `E` is the error type exposed by the `Service` implementation, and `F` is the
/// callback that picks the `Content-Type` header value for a file.
pub struct ServeDir<E, F> {
/// Root directory being served; request paths are joined onto it and the
/// resolved result is checked against it.
path: PathBuf,
/// Callback that maps a resolved file path to its `Content-Type` header value.
mime_getter: F,
/// Records the error type `E` without storing a value of it. The `fn(E)` form
/// keeps the marker `Send` and `Sync` regardless of `E`.
_marker: PhantomData<fn(E)>,
}
impl<E> ServeDir<E, fn(&Path) -> HeaderValue> {
    /// Creates a service that serves files from the directory `path`.
    ///
    /// The path is canonicalized once here, so every request can later be compared
    /// against an absolute root with symbolic links already resolved. MIME types are
    /// guessed by [`guess_mime`] unless replaced through [`ServeDir::mime_getter`].
    ///
    /// # Panics
    ///
    /// Panics if `path` cannot be canonicalized (for example, it does not exist),
    /// or if the canonicalized path is not a directory.
    pub fn new<P>(path: P) -> Self
    where
        P: AsRef<Path>,
    {
        let path = fs::canonicalize(path).expect("ServeDir: failed to canonicalize path");
        // Name the offending path in the panic so a misconfiguration is easy to spot.
        assert!(
            path.is_dir(),
            "ServeDir: {} is not a directory",
            path.display()
        );
        Self {
            path,
            mime_getter: guess_mime,
            _marker: PhantomData,
        }
    }

    /// Replaces the MIME type callback, keeping the same root directory.
    ///
    /// `mime_getter` receives the resolved path of the file being served and returns
    /// the header value to use as its content type.
    pub fn mime_getter<F>(self, mime_getter: F) -> ServeDir<E, F>
    where
        F: Fn(&Path) -> HeaderValue,
    {
        ServeDir {
            path: self.path,
            mime_getter,
            _marker: PhantomData,
        }
    }
}
/// Serves the file addressed by the request path, relative to the configured root.
///
/// This implementation never returns `Err`; every failure is turned into a
/// status-only response:
///
/// - `404 Not Found` when the path cannot be canonicalized or is not a regular file,
/// - `403 Forbidden` when the canonicalized path is not under the root directory,
/// - `500 Internal Server Error` when the file response cannot be constructed.
impl<B, E, F> Service<ServerContext, Request<B>> for ServeDir<E, F>
where
    B: Send,
    F: Fn(&Path) -> HeaderValue + Sync,
{
    type Response = Response;
    type Error = E;

    async fn call(
        &self,
        _: &mut ServerContext,
        req: Request<B>,
    ) -> Result<Self::Response, Self::Error> {
        // Drop one leading slash so the request path joins onto the root as a relative path.
        let uri_path = req.uri().path();
        let path = uri_path.strip_prefix('/').unwrap_or(uri_path);
        tracing::trace!("[Volo-HTTP] ServeDir: path: {path}");

        // NOTE(review): `fs::canonicalize` and `is_file` are synchronous `std::fs`
        // calls made from inside an async fn.
        let resolved = match fs::canonicalize(self.path.join(path)) {
            Ok(resolved) => resolved,
            Err(_) => return Ok(StatusCode::NOT_FOUND.into_response()),
        };

        // After canonicalization, anything not under the root (e.g. reached via `..`)
        // is rejected.
        if !resolved.starts_with(&self.path) {
            tracing::debug!(
                "[Volo-HTTP] ServeDir: illegal path: {}",
                resolved.display()
            );
            return Ok(StatusCode::FORBIDDEN.into_response());
        }

        // Only regular files are served.
        if !resolved.is_file() {
            return Ok(StatusCode::NOT_FOUND.into_response());
        }

        let content_type = (self.mime_getter)(&resolved);
        match FileResponse::new(resolved, content_type) {
            Ok(resp) => Ok(resp.into_response()),
            Err(_) => Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
        }
    }
}
/// Default MIME type callback: asks `mime_guess` for the content type of `path`.
///
/// Uses the first guess when there is one, and falls back to
/// `mime::APPLICATION_OCTET_STREAM` otherwise.
pub fn guess_mime(path: &Path) -> HeaderValue {
    match mime_guess::from_path(path).first_raw() {
        Some(raw) => HeaderValue::from_static(raw),
        None => HeaderValue::from_str(mime::APPLICATION_OCTET_STREAM.as_ref()).unwrap(),
    }
}
#[cfg(test)]
mod serve_dir_tests {
    use http::{StatusCode, method::Method};

    use super::ServeDir;
    use crate::{
        body::Body,
        server::{Router, Server},
    };

    /// Mounts the current directory under `/static/` and checks the status code
    /// returned for existing files, a missing file, and a path that climbs out of
    /// the served root.
    #[tokio::test]
    async fn read_file() {
        let router: Router<Option<Body>> =
            Router::new().nest_service("/static/", ServeDir::new("."));
        let server = Server::new(router).into_test_server();

        // Files expected to exist in the directory the test runs from.
        let manifest = server
            .call_route(Method::GET, "/static/Cargo.toml", None)
            .await;
        assert!(manifest.status().is_success());

        let lib_rs = server
            .call_route(Method::GET, "/static/src/lib.rs", None)
            .await;
        assert!(lib_rs.status().is_success());

        // A file expected to be absent from that directory.
        let lock_file = server
            .call_route(Method::GET, "/static/Cargo.lock", None)
            .await;
        assert_eq!(lock_file.status(), StatusCode::NOT_FOUND);

        // A `..` segment resolves outside the served root and must be refused.
        let escaped = server
            .call_route(Method::GET, "/static/../Cargo.toml", None)
            .await;
        assert_eq!(escaped.status(), StatusCode::FORBIDDEN);
    }
}