use std::convert::Infallible;
use std::fmt::Debug;
use std::fs::File;
use std::path::Path;
use std::sync::LazyLock;
use std::{collections::HashMap, path::PathBuf};
#[cfg(feature = "runtime-compression")]
use crate::compression::BrotliLevel;
use crate::compression::{
CompressionStrategy, CompressionStrategyInner, CompressionSupport, MatchedFile,
StaticCompression,
};
#[cfg(feature = "hyper")]
use crate::integration::HyperService;
#[cfg(feature = "tower")]
use crate::integration::{TowerLayer, TowerService};
use crate::etag::EtagCache;
use crate::{Body, ETag, FileEntity, FileHasher, FileInfo, SerdirError};
use http::{header, HeaderMap, HeaderValue, Request, Response, StatusCode};
/// A directory of static files that can be served over HTTP.
///
/// Create one with [`ServedDir::builder`], then resolve requests with [`ServedDir::get`] or
/// [`ServedDir::get_response`].
pub struct ServedDir {
    // --- Request path resolution ---
    /// Root directory that request paths are resolved against.
    dirpath: PathBuf,
    /// Optional prefix removed from request paths before resolution.
    strip_prefix: Option<String>,
    /// When true, a request resolving to a directory is retried as `<dir>/index.html`.
    append_index_html: bool,
    /// Page (already joined onto `dirpath`) returned inside `NotFound` when a file is missing.
    not_found_path: Option<PathBuf>,

    // --- Response headers ---
    /// Extension -> `Content-Type` header value.
    known_extensions: HashMap<String, HeaderValue>,
    /// `Content-Type` used when the extension is not in `known_extensions`.
    default_content_type: HeaderValue,
    /// Headers copied into every file response.
    common_headers: HeaderMap,

    // --- Lookup and validation ---
    /// Strategy used to locate files; a non-`None` strategy also adds `Vary: Accept-Encoding`.
    compression_strategy: CompressionStrategyInner,
    /// Function that hashes a file to produce its ETag.
    file_hasher: FileHasher,
    /// Cache of computed ETags keyed by file metadata.
    etag_cache: EtagCache,
}
impl Debug for ServedDir {
    /// Formats the public-facing configuration of the directory.
    ///
    /// The compression strategy is shown as a short label (`"static"`, `"none"` or `"cached"`)
    /// instead of its internal state; headers, hasher and cache are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label: &str = match &self.compression_strategy {
            CompressionStrategyInner::None => "none",
            CompressionStrategyInner::Static(_) => "static",
            #[cfg(feature = "runtime-compression")]
            CompressionStrategyInner::Cached(_) => "cached",
        };

        let mut out = f.debug_struct("ServedDir");
        out.field("dirpath", &self.dirpath);
        out.field("strip_prefix", &self.strip_prefix);
        out.field("append_index_html", &self.append_index_html);
        out.field("not_found_path", &self.not_found_path);
        out.field("default_content_type", &self.default_content_type);
        out.field("compression_strategy", &label);
        out.finish()
    }
}
/// Initial default `Content-Type` assigned by `ServedDirBuilder::new`. Responses fall back to the
/// configured default content type when a file's extension is not in the known-extensions table.
static OCTET_STREAM: HeaderValue = HeaderValue::from_static("application/octet-stream");
/// Hasher used for ETags when the builder is not given a custom one.
///
/// Hashes the file via `rapidhash::v3::rapidhash_v3_file` and always yields `Some(hash)`.
///
/// # Errors
///
/// Propagates any I/O error raised while reading the file.
pub(crate) fn default_hasher(file: &File) -> Result<Option<u64>, std::io::Error> {
    Ok(Some(rapidhash::v3::rapidhash_v3_file(file)?))
}
impl ServedDir {
    /// Starts building a [`ServedDir`] rooted at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`SerdirError::ConfigError`] if `path` is not an existing directory.
    pub fn builder(path: impl Into<PathBuf>) -> Result<ServedDirBuilder, SerdirError> {
        ServedDirBuilder::new(path.into())
    }

    /// Returns the root directory that request paths are resolved against.
    pub fn dir(&self) -> &Path {
        &self.dirpath
    }

    /// Resolves a request path to a [`FileEntity`] ready to be served.
    ///
    /// Processing steps:
    /// 1. If a strip prefix is configured, it is removed. The prefix only matches on a `/`
    ///    boundary (or when the prefix itself ends with `/`). For example, `/static` matches
    ///    `/static` and `/static/app.js` but not `/staticfoo`. A path equal to the prefix
    ///    resolves to the root directory.
    /// 2. One leading `/` is dropped and the rest is validated (no NUL, `..`, root or prefix).
    /// 3. The file is looked up through the compression strategy, using the encodings that
    ///    `CompressionSupport::detect` reports for `req_hdrs`.
    /// 4. If the path is a directory and `append_index_html` is enabled, `index.html` inside it
    ///    is looked up instead.
    ///
    /// # Errors
    ///
    /// - [`SerdirError::NotFound`]`(None)` if the strip prefix does not match.
    /// - [`SerdirError::NotFound`]`(Some(entity))` carrying the configured not-found page when
    ///   the lookup (including a directory's `index.html`) reports `NotFound`. Without a
    ///   not-found page, that `NotFound` error is returned unchanged.
    /// - [`SerdirError::IsDirectory`] if the path is a directory and `append_index_html` is off.
    /// - [`SerdirError::InvalidPath`] if the path fails validation.
    /// - Any other error raised while locating the file or computing its ETag.
    pub async fn get(&self, path: &str, req_hdrs: &HeaderMap) -> Result<FileEntity, SerdirError> {
        let path = match self.strip_prefix.as_deref() {
            Some(prefix) if path == prefix => ".",
            Some(prefix) => Self::strip_path_prefix(path, prefix)
                .ok_or_else(|| SerdirError::NotFound(None))?,
            None => path,
        };
        let path = path.strip_prefix('/').unwrap_or(path);
        let full_path = self.validate_path(path)?;
        let preferred = CompressionSupport::detect(req_hdrs);

        let found = match self.find_file(&full_path, preferred).await {
            Err(SerdirError::IsDirectory(_)) if self.append_index_html => {
                self.find_file(&full_path.join("index.html"), preferred).await
            }
            other => other,
        };

        match found {
            Ok(matched_file) => self.create_entity(matched_file),
            // A missing `index.html` also lands here, so the custom not-found page is used for
            // every "nothing at this path" outcome rather than only for plain files.
            Err(err @ SerdirError::NotFound(_)) => match &self.not_found_path {
                Some(not_found_path) => {
                    let matched_file = self.find_file(not_found_path, preferred).await?;
                    let entity = self.create_entity(matched_file)?;
                    Err(SerdirError::NotFound(Some(entity)))
                }
                None => Err(err),
            },
            Err(err) => Err(err),
        }
    }

    /// Removes `prefix` from the start of `path`, but only at a path-segment boundary.
    ///
    /// Returns `None` in two cases. The first is when `path` does not start with `prefix`. The
    /// second is when the match would end in the middle of a segment (prefix `/static` against
    /// `/staticfoo`), which would otherwise expose the directory under unrelated URLs.
    fn strip_path_prefix<'a>(path: &'a str, prefix: &str) -> Option<&'a str> {
        let rest = path.strip_prefix(prefix)?;
        let on_boundary = prefix.is_empty()
            || prefix.ends_with('/')
            || rest.is_empty()
            || rest.starts_with('/');
        on_boundary.then_some(rest)
    }

    /// Serves `req` as a complete HTTP response; this never returns `Err`.
    ///
    /// Outcome mapping:
    /// - file found: 200.
    /// - not-found page configured: 404 with that page.
    /// - not found or directory: bare 404.
    /// - invalid path: 400 (logged).
    /// - any other error: 500 (logged).
    pub async fn get_response<B>(&self, req: &Request<B>) -> Result<Response<Body>, Infallible> {
        // NOTE(review): `Uri::path()` is not percent-decoded, so `/a%20b.txt` is looked up as the
        // literal name `a%20b.txt`; confirm whether decoding is meant to happen upstream.
        let response = match self.get(req.uri().path(), req.headers()).await {
            Ok(entity) => entity.serve_request(req, StatusCode::OK),
            Err(SerdirError::NotFound(Some(entity))) => {
                entity.serve_request(req, StatusCode::NOT_FOUND)
            }
            Err(SerdirError::NotFound(None)) | Err(SerdirError::IsDirectory(_)) => {
                Self::make_status_response(StatusCode::NOT_FOUND)
            }
            Err(SerdirError::InvalidPath(msg)) => {
                log::error!("Invalid path: {msg}");
                Self::make_status_response(StatusCode::BAD_REQUEST)
            }
            Err(e) => {
                log::error!("Internal server error: {e}");
                Self::make_status_response(StatusCode::INTERNAL_SERVER_ERROR)
            }
        };
        Ok(response)
    }

    /// Builds a body-only response whose text is the status's canonical reason phrase.
    fn make_status_response(status: StatusCode) -> Response<Body> {
        let reason = status.canonical_reason().unwrap_or("Unknown");
        Response::builder()
            .status(status)
            .body(Body::from(reason))
            .expect("status response should be valid")
    }

    /// Builds the [`FileEntity`] for a located file.
    ///
    /// The response headers are built in order:
    /// - Start from the common headers.
    /// - Set `Content-Type` from the extension table, falling back to the default content type.
    /// - Add `Content-Encoding` when the matched variant has one.
    /// - Add `Vary: Accept-Encoding` whenever a compression strategy is active.
    fn create_entity(&self, matched_file: MatchedFile) -> Result<FileEntity, SerdirError> {
        let content_type = self
            .known_extensions
            .get(&matched_file.extension)
            .unwrap_or(&self.default_content_type)
            .clone();
        let mut headers = self.common_headers.clone();
        headers.insert(header::CONTENT_TYPE, content_type);
        if let Some(value) = matched_file.content_encoding.get_header_value() {
            headers.insert(header::CONTENT_ENCODING, value);
        }
        if !self.compression_strategy.is_none() {
            // `append`, not `insert`: a `Vary` supplied through the common headers (e.g. `Origin`)
            // must survive, otherwise caches could serve the wrong variant.
            headers.append(header::VARY, HeaderValue::from_static("Accept-Encoding"));
        }
        let etag = self.calculate_etag(matched_file.file_info, matched_file.file.as_ref())?;
        Ok(FileEntity::new_with_metadata(
            matched_file.file,
            matched_file.file_info,
            headers,
            etag,
        ))
    }

    /// Locates `path` (or an encoded variant of it) via the configured compression strategy.
    async fn find_file(
        &self,
        path: &Path,
        preferred: CompressionSupport,
    ) -> Result<MatchedFile, SerdirError> {
        self.compression_strategy.find_file(path, preferred).await
    }

    /// Returns the ETag for a file, consulting the cache keyed by `file_info` first.
    ///
    /// On a cache miss the configured hasher runs and its result is cached, including a `None`
    /// result, which means "no ETag".
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the hasher.
    fn calculate_etag(
        &self,
        file_info: FileInfo,
        file: &File,
    ) -> Result<Option<ETag>, std::io::Error> {
        if let Some(etag) = self.etag_cache.get(&file_info) {
            return Ok(etag);
        }
        let hash_file = || (self.file_hasher)(file).map(|hash| hash.map(ETag::from));
        // The hasher does synchronous file I/O. `block_in_place` lets a multi-threaded runtime
        // move other tasks off this worker, but it panics on a current-thread runtime, so it is
        // only used where Tokio permits it.
        let etag = if Self::can_block_in_place() {
            tokio::task::block_in_place(hash_file)
        } else {
            hash_file()
        }?;
        self.etag_cache.insert(file_info, etag);
        Ok(etag)
    }

    /// Whether `tokio::task::block_in_place` may be called on the current thread.
    ///
    /// Tokio only allows it on a multi-threaded runtime. Outside any runtime, calling the work
    /// directly is equivalent.
    fn can_block_in_place() -> bool {
        tokio::runtime::Handle::try_current().is_ok_and(|handle| {
            matches!(
                handle.runtime_flavor(),
                tokio::runtime::RuntimeFlavor::MultiThread
            )
        })
    }

    /// Joins a relative request path onto the served directory.
    ///
    /// # Errors
    ///
    /// Returns [`SerdirError::InvalidPath`] if the path contains a NUL byte, a `..` segment, a
    /// root, or a platform prefix. `.` segments are ignored.
    fn validate_path(&self, path: &str) -> Result<PathBuf, SerdirError> {
        use std::path::Component;

        if path.as_bytes().contains(&0) {
            return Err(SerdirError::InvalidPath(
                "path contains NUL byte".to_string(),
            ));
        }
        let mut full_path = self.dirpath.clone();
        for component in Path::new(path).components() {
            match component {
                Component::Normal(seg) => full_path.push(seg),
                Component::CurDir => {}
                Component::ParentDir => {
                    return Err(SerdirError::InvalidPath(
                        "path contains .. segment".to_string(),
                    ));
                }
                Component::RootDir => {
                    return Err(SerdirError::InvalidPath("path is absolute".to_string()));
                }
                Component::Prefix(_) => {
                    return Err(SerdirError::InvalidPath(
                        "path contains a prefix".to_string(),
                    ));
                }
            }
        }
        Ok(full_path)
    }

    /// Wraps this directory in a Tower service.
    #[cfg(feature = "tower")]
    pub fn into_tower_service(self) -> TowerService {
        TowerService::new(self)
    }

    /// Wraps this directory in a Hyper service.
    #[cfg(feature = "hyper")]
    pub fn into_hyper_service(self) -> HyperService {
        HyperService::new(self)
    }

    /// Wraps this directory in a Tower layer.
    #[cfg(feature = "tower")]
    pub fn into_tower_layer(self) -> TowerLayer {
        TowerLayer::new(self)
    }
}
/// Builder for [`ServedDir`].
///
/// Obtain one from [`ServedDir::builder`] or [`ServedDirBuilder::new`], configure it with the
/// chained setters, and finish with [`ServedDirBuilder::build`].
#[derive(Debug)]
pub struct ServedDirBuilder {
    /// Root directory; `new` rejects paths that are not directories.
    dirpath: PathBuf,
    /// Compression strategy, converted to its internal form by `build`.
    compression_strategy: CompressionStrategy,
    /// Custom ETag hasher; `None` makes `build` use `default_hasher`.
    file_hasher: Option<FileHasher>,
    /// Prefix removed from request paths before resolution.
    strip_prefix: Option<String>,
    /// Extension -> `Content-Type` header value table.
    known_extensions: HashMap<String, HeaderValue>,
    /// `Content-Type` for extensions missing from `known_extensions`.
    default_content_type: HeaderValue,
    /// Headers added to every file response.
    common_headers: HeaderMap,
    /// Whether directory requests fall back to `index.html`.
    append_index_html: bool,
    /// Not-found page, stored already joined onto `dirpath`.
    not_found_path: Option<PathBuf>,
}
impl ServedDirBuilder {
    /// Creates a builder that serves files from `dirpath`.
    ///
    /// Defaults:
    /// - no compression;
    /// - the default file hasher;
    /// - no strip prefix;
    /// - the built-in extension -> content-type table;
    /// - `application/octet-stream` for unknown extensions;
    /// - no common headers;
    /// - `index.html` appending disabled;
    /// - no not-found page.
    ///
    /// # Errors
    ///
    /// Returns [`SerdirError::ConfigError`] if `dirpath` is not an existing directory.
    pub fn new(dirpath: impl Into<PathBuf>) -> Result<Self, SerdirError> {
        let dirpath = dirpath.into();
        if !dirpath.is_dir() {
            let msg = format!("path is not a directory: {}", dirpath.display());
            return Err(SerdirError::ConfigError(msg));
        }
        Ok(Self {
            dirpath,
            compression_strategy: CompressionStrategy::none(),
            file_hasher: None,
            strip_prefix: None,
            known_extensions: Self::default_extensions(),
            default_content_type: OCTET_STREAM.clone(),
            common_headers: HeaderMap::new(),
            append_index_html: false,
            not_found_path: None,
        })
    }

    /// Sets the compression strategy used when locating files.
    pub fn compression(mut self, strategy: impl Into<CompressionStrategy>) -> Self {
        self.compression_strategy = strategy.into();
        self
    }

    /// Shorthand for [`ServedDirBuilder::compression`] with a [`StaticCompression`] strategy.
    /// Each flag enables one encoding: `br` for Brotli, `gzip` for gzip, `zstd` for zstd.
    pub fn static_compression(self, br: bool, gzip: bool, zstd: bool) -> Self {
        let strategy = StaticCompression::none().brotli(br).gzip(gzip).zstd(zstd);
        self.compression(strategy)
    }

    /// Shorthand for [`ServedDirBuilder::compression`] with a `CachedCompression` strategy at the
    /// given Brotli level.
    #[cfg(feature = "runtime-compression")]
    pub fn cached_compression(self, level: BrotliLevel) -> Self {
        use crate::compression::CachedCompression;
        let strategy = CachedCompression::new().compression_level(level);
        self.compression(strategy)
    }

    /// Disables compression (the default).
    pub fn no_compression(self) -> Self {
        self.compression(CompressionStrategy::none())
    }

    /// When enabled, a request that resolves to a directory is served from its `index.html`.
    pub fn append_index_html(mut self, append: bool) -> Self {
        self.append_index_html = append;
        self
    }

    /// Overrides the function used to hash files for ETags.
    pub fn file_hasher(mut self, file_hasher: FileHasher) -> Self {
        self.file_hasher = Some(file_hasher);
        self
    }

    /// Sets a prefix removed from request paths before they are resolved.
    pub fn strip_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.strip_prefix = Some(prefix.into());
        self
    }

    /// Replaces the whole extension -> content-type table.
    pub fn known_extensions(mut self, extensions: HashMap<String, HeaderValue>) -> Self {
        self.known_extensions = extensions;
        self
    }

    /// Adds (or replaces) a single extension -> content-type mapping.
    pub fn known_extension(
        mut self,
        extension: impl Into<String>,
        content_type: HeaderValue,
    ) -> Self {
        self.known_extensions.insert(extension.into(), content_type);
        self
    }

    /// Sets the `Content-Type` used for extensions not in the table.
    pub fn default_content_type(mut self, content_type: HeaderValue) -> Self {
        self.default_content_type = content_type;
        self
    }

    /// Adds a header to every file response, replacing any earlier value for the same name.
    pub fn common_header(mut self, name: header::HeaderName, value: HeaderValue) -> Self {
        self.common_headers.insert(name, value);
        self
    }

    /// Sets a page, relative to the served directory, returned with the `NotFound` error.
    ///
    /// Note that `..` segments are not rejected, so the page may live outside the served
    /// directory.
    ///
    /// # Errors
    ///
    /// Returns [`SerdirError::ConfigError`] in two cases:
    /// - `path` is absolute or rooted;
    /// - the joined path is not an existing file.
    pub fn not_found_path(mut self, path: impl Into<PathBuf>) -> Result<Self, SerdirError> {
        let path = path.into();
        if path.is_absolute() || path.has_root() {
            return Err(SerdirError::ConfigError(
                "not_found_path must be relative".to_string(),
            ));
        }
        let full_path = self.dirpath.join(path);
        if !full_path.is_file() {
            return Err(SerdirError::ConfigError(format!(
                "not_found_path is not a file: {}",
                full_path.display()
            )));
        }
        self.not_found_path = Some(full_path);
        Ok(self)
    }

    /// Finalizes the configuration into a [`ServedDir`] with an empty ETag cache.
    pub fn build(self) -> ServedDir {
        ServedDir {
            dirpath: self.dirpath,
            compression_strategy: self.compression_strategy.into_inner(),
            file_hasher: self.file_hasher.unwrap_or(default_hasher),
            strip_prefix: self.strip_prefix,
            known_extensions: self.known_extensions,
            default_content_type: self.default_content_type,
            common_headers: self.common_headers,
            append_index_html: self.append_index_html,
            not_found_path: self.not_found_path,
            etag_cache: EtagCache::new(),
        }
    }

    /// Returns a fresh copy of the built-in extension -> content-type table.
    ///
    /// The table is built once and cloned per builder so callers can mutate their copy.
    fn default_extensions() -> HashMap<String, HeaderValue> {
        static DEFAULT_EXTENSIONS: LazyLock<HashMap<String, HeaderValue>> = LazyLock::new(|| {
            let extensions = [
                ("html", "text/html"),
                ("htm", "text/html"),
                ("hxt", "text/html"),
                ("xhtml", "application/xhtml+xml"),
                ("css", "text/css"),
                // JavaScript: RFC 9239 designates text/javascript. `.mjs` needs a JavaScript
                // type or browsers refuse to load it as a module script.
                ("js", "text/javascript"),
                ("mjs", "text/javascript"),
                ("es", "text/javascript"),
                ("ecma", "text/javascript"),
                ("jsm", "text/javascript"),
                ("jsx", "text/javascript"),
                // `WebAssembly.instantiateStreaming` requires application/wasm.
                ("wasm", "application/wasm"),
                ("webmanifest", "application/manifest+json"),
                ("png", "image/png"),
                ("apng", "image/apng"),
                ("avif", "image/avif"),
                ("gif", "image/gif"),
                ("ico", "image/x-icon"),
                ("jpeg", "image/jpeg"),
                ("jfif", "image/jpeg"),
                ("pjpeg", "image/jpeg"),
                ("pjp", "image/jpeg"),
                ("jpg", "image/jpeg"),
                ("svg", "image/svg+xml"),
                ("tiff", "image/tiff"),
                ("webp", "image/webp"),
                ("bmp", "image/bmp"),
                // Fonts (RFC 8081 `font/*` types).
                ("woff", "font/woff"),
                ("woff2", "font/woff2"),
                ("ttf", "font/ttf"),
                ("otf", "font/otf"),
                ("pdf", "application/pdf"),
                ("zip", "application/zip"),
                ("gz", "application/gzip"),
                ("tar", "application/tar"),
                ("bz", "application/x-bzip"),
                ("bz2", "application/x-bzip2"),
                ("xz", "application/x-xz"),
                ("csv", "text/csv"),
                ("txt", "text/plain"),
                ("text", "text/plain"),
                ("log", "text/plain"),
                // All Markdown extensions use the registered type (RFC 7763).
                ("md", "text/markdown"),
                ("markdown", "text/markdown"),
                ("mkd", "text/markdown"),
                ("mp3", "audio/mpeg"),
                ("wav", "audio/wav"),
                ("mp4", "video/mp4"),
                ("webm", "video/webm"),
                ("mpeg", "video/mpeg"),
                ("mpg", "video/mpeg"),
                ("mpg4", "video/mp4"),
                ("xml", "application/xml"),
                ("json", "application/json"),
                ("yaml", "application/yaml"),
                ("yml", "application/yaml"),
                ("toml", "application/toml"),
                ("ini", "application/ini"),
                ("ics", "text/calendar"),
                ("doc", "application/msword"),
                (
                    "docx",
                    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                ),
                ("xls", "application/vnd.ms-excel"),
                (
                    "xlsx",
                    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                ),
                ("ppt", "application/vnd.ms-powerpoint"),
                (
                    "pptx",
                    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
                ),
            ];
            extensions
                .into_iter()
                .map(|(ext, ct)| (ext.to_string(), HeaderValue::from_static(ct)))
                .collect()
        });
        DEFAULT_EXTENSIONS.clone()
    }
}