use std::borrow::Cow;
use std::path::{Component, Path, PathBuf};
use std::str::FromStr;
use async_compression::futures::bufread::{BrotliEncoder, GzipEncoder};
use async_fs::{File, metadata};
use async_trait::async_trait;
use bytes::Bytes;
use futures::io::{AsyncRead, AsyncReadExt, BufReader};
use futures_util::StreamExt;
use futures_util::stream::{self, BoxStream};
use headers::ContentType;
use http::header::CONTENT_LENGTH;
use mime::CHARSET;
use crate::prelude::stream_body;
use crate::{Handler, Request, Response, SilentError, StatusCode};
use super::StaticOptions;
use super::compression::{Compression, apply_headers, negotiate};
use super::directory::render_directory_listing;
/// Handler that serves files from a directory on the local filesystem.
///
/// Each request is resolved against `root` using the `path` route parameter.
pub struct HandlerWrapperStatic {
    /// Directory that request paths are joined onto (trailing slashes removed).
    root: PathBuf,
    /// Serving options; consulted for directory listing and compression negotiation.
    options: StaticOptions,
}

impl HandlerWrapperStatic {
    /// Creates the handler, panicking when `path` is not an existing directory.
    ///
    /// # Panics
    /// Panics with `Path not exists: <path>` when [`Self::try_new`] fails.
    fn new(path: &str, options: StaticOptions) -> Self {
        Self::try_new(path, options).unwrap_or_else(|_| {
            let normalized = Self::normalize_root(path);
            panic!("Path not exists: {normalized}");
        })
    }

    /// Creates the handler, returning an error when `path` is not an existing directory.
    ///
    /// Trailing slashes are stripped (`"static/"` becomes `"static"`), while a path made
    /// only of slashes (`"/"`, `"//"`) is kept as the filesystem root `"/"`.
    ///
    /// # Errors
    /// Returns a business error with `500 Internal Server Error` when the normalized
    /// path is not a directory.
    pub fn try_new(path: &str, options: StaticOptions) -> Result<Self, SilentError> {
        let normalized = Self::normalize_root(path);
        if !Path::new(normalized).is_dir() {
            return Err(SilentError::business_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("static path not exists: {normalized}"),
            ));
        }
        Ok(Self {
            root: PathBuf::from(normalized),
            options,
        })
    }

    /// Removes trailing slashes from a root path.
    ///
    /// Plain `trim_end_matches('/')` would turn an all-slash path such as `"//"` into
    /// `""`, which no longer names a directory; such input is mapped to `"/"` instead.
    /// An empty input stays empty.
    fn normalize_root(path: &str) -> &str {
        let trimmed = path.trim_end_matches('/');
        if trimmed.is_empty() && !path.is_empty() {
            "/"
        } else {
            trimmed
        }
    }

    /// Percent-decodes the raw `path` route parameter.
    ///
    /// # Errors
    /// Returns `404 Not Found` when the decoded bytes are not valid UTF-8.
    fn decode_param(param: &str) -> Result<String, SilentError> {
        urlencoding::decode(param)
            .map(Cow::into_owned)
            .map_err(|_| SilentError::BusinessError {
                code: StatusCode::NOT_FOUND,
                msg: "Not Found".to_string(),
            })
    }

    /// Rebuilds a relative path from only the normal components of `trimmed`.
    ///
    /// `.` components are dropped. Returns `None` if any component is `..`, a root or a
    /// prefix, so the result can never escape the directory it is joined onto.
    fn sanitize_path_param(trimmed: &str) -> Option<PathBuf> {
        let mut sanitized = PathBuf::new();
        for component in Path::new(trimmed).components() {
            match component {
                Component::Normal(seg) => sanitized.push(seg),
                Component::CurDir => {}
                Component::ParentDir | Component::RootDir | Component::Prefix(_) => return None,
            }
        }
        Some(sanitized)
    }

    /// Renders `sanitized` as a `/`-separated request path for directory listings.
    ///
    /// Returns an empty string for the root. Otherwise a trailing `/` is appended when
    /// `ends_with_slash` is set.
    fn normalized_request_path(sanitized: &Path, ends_with_slash: bool) -> String {
        let parts: Vec<String> = sanitized
            .components()
            .filter_map(|component| match component {
                Component::Normal(seg) => Some(seg.to_string_lossy().into_owned()),
                _ => None,
            })
            .collect();
        if parts.is_empty() {
            return String::new();
        }
        let mut joined = parts.join("/");
        if ends_with_slash {
            joined.push('/');
        }
        joined
    }
}
#[async_trait]
impl Handler for HandlerWrapperStatic {
    /// Serves the file or directory listing addressed by the `path` route parameter.
    ///
    /// Resolution order:
    /// 1. The parameter is percent-decoded and sanitized. `..`, absolute or prefixed
    ///    components are rejected.
    /// 2. With directory listing enabled, a path that exists as a directory is rendered as
    ///    a listing.
    /// 3. Otherwise a trailing slash (or an existing directory) maps to its `index.html`.
    /// 4. The target file is streamed, compressed when [`negotiate`] selects an encoding.
    ///
    /// # Errors
    /// Returns `SilentError::BusinessError` with `404 Not Found` in these cases:
    /// - the parameter is missing;
    /// - the parameter cannot be decoded;
    /// - the path is rejected by sanitization;
    /// - the target file cannot be opened.
    ///
    /// Errors from the directory-listing renderer are propagated unchanged.
    async fn call(&self, req: Request) -> Result<Response, SilentError> {
        // All "cannot serve this" outcomes in this method produce the same 404 error.
        let not_found = || SilentError::BusinessError {
            code: StatusCode::NOT_FOUND,
            msg: "Not Found".to_string(),
        };

        let Ok(file_path) = req.get_path_params::<String>("path") else {
            return Err(not_found());
        };
        let decoded = Self::decode_param(&file_path)?;
        // An empty parameter addresses the root directory itself.
        let ends_with_slash = decoded.ends_with('/') || decoded.is_empty();
        let sanitized =
            Self::sanitize_path_param(decoded.trim_start_matches('/')).ok_or_else(not_found)?;
        let fs_path = self.root.join(&sanitized);
        // Decide "directory" from the filesystem, not from the trailing slash alone. A
        // trailing slash on a missing path or on a regular file must not be handed to the
        // listing renderer; it falls through to the `index.html` lookup (and 404) instead.
        let is_existing_dir = metadata(&fs_path).await.is_ok_and(|m| m.is_dir());

        if self.options.directory_listing && is_existing_dir {
            let normalized = Self::normalized_request_path(&sanitized, ends_with_slash);
            return render_directory_listing(&normalized, fs_path.as_path()).await;
        }

        let target_path = if ends_with_slash || is_existing_dir {
            fs_path.join("index.html")
        } else {
            fs_path
        };
        let file = File::open(&target_path).await.map_err(|_| not_found())?;

        let mut res = Response::empty();
        let guessed_mime = mime_guess::from_path(&target_path).first();
        res.set_typed_header(normalize_content_type(guessed_mime.clone()));
        let stream = match negotiate(&self.options, &req, guessed_mime.as_ref()) {
            Some(kind) => {
                apply_headers(&mut res, &kind);
                // The encoders consume a buffered reader.
                match kind {
                    Compression::Brotli => to_stream(BrotliEncoder::new(BufReader::new(file))),
                    Compression::Gzip => to_stream(GzipEncoder::new(BufReader::new(file))),
                }
            }
            None => to_stream(file),
        };
        // The body is streamed, so no Content-Length is sent.
        res.headers_mut().remove(CONTENT_LENGTH);
        res.set_body(stream_body(stream));
        Ok(res)
    }
}
/// Adapts an async reader into a boxed stream of byte chunks.
///
/// Each item holds the bytes returned by one `read` call, up to 16 KiB. The stream ends
/// when a read returns zero bytes. The first read error is yielded and ends the stream.
fn to_stream<R>(reader: R) -> BoxStream<'static, Result<Bytes, std::io::Error>>
where
    R: AsyncRead + Unpin + Send + 'static,
{
    const CHUNK_SIZE: usize = 16 * 1024;
    // The scratch buffer travels with the reader through the unfold state and is reused
    // for every read; only the filled prefix is copied out into each chunk.
    let initial_state = (reader, vec![0u8; CHUNK_SIZE]);
    let chunks = stream::try_unfold(initial_state, |state| async move {
        let (mut source, mut scratch) = state;
        let filled = source.read(&mut scratch).await?;
        if filled == 0 {
            return Ok::<_, std::io::Error>(None);
        }
        let chunk = Bytes::copy_from_slice(&scratch[..filled]);
        Ok(Some((chunk, (source, scratch))))
    });
    chunks.boxed()
}
/// Converts a guessed MIME type into a `Content-Type` header value.
///
/// - A `text/*` type without an explicit `charset` parameter gets `charset=utf-8`
///   appended. The complete original type is kept, including any `+suffix` and other
///   parameters.
/// - Any other type is passed through unchanged.
/// - `None` falls back to `application/octet-stream`.
fn normalize_content_type(mime: Option<mime::Mime>) -> ContentType {
    let Some(value) = mime else {
        return ContentType::octet_stream();
    };
    if value.type_() != mime::TEXT || value.get_param(CHARSET).is_some() {
        return ContentType::from(value);
    }
    // `Display` for `Mime` writes the full source string (type, subtype, suffix and
    // parameters). Rebuilding from `type_()`/`subtype()` would silently drop a `+suffix`
    // (`subtype()` excludes it) and any other parameters.
    mime::Mime::from_str(&format!("{value}; charset=utf-8"))
        .map(ContentType::from)
        .unwrap_or_else(|_| ContentType::text_utf8())
}
/// Builds a static-file handler for the directory at `path` using default options.
///
/// # Panics
/// Panics if `path` does not name an existing directory.
pub fn static_handler(path: &str) -> impl Handler {
    static_handler_with_options(path, StaticOptions::default())
}

/// Builds a static-file handler for the directory at `path` with explicit `options`
/// (for example directory listing or compression).
///
/// # Panics
/// Panics if `path` does not name an existing directory.
pub fn static_handler_with_options(path: &str, options: StaticOptions) -> impl Handler {
    HandlerWrapperStatic::new(path, options)
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE};
    use http_body_util::BodyExt;
    use crate::core::path_param::PathString;
    use crate::prelude::*;
    use crate::{Handler, Request, SilentError, StatusCode};
    use super::{HandlerWrapperStatic, StaticOptions};

    /// Contents written to every fixture's `index.html`.
    static CONTENT: &str = r#"<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Silent</title>
</head>
<body>
<h1>我的第一个标题</h1>
<p>我的第一个段落。</p>
</body>
</html>"#;

    impl PathParam {
        /// Test helper: builds an owned `Path` route parameter. The enclosing module is
        /// already `#[cfg(test)]`, so no extra attribute is needed here.
        pub(crate) fn path_owned(value: String) -> Self {
            PathParam::Path(PathString::Owned(value))
        }
    }

    /// Creates a fresh fixture tree at `path`. It contains `index.html`, `hello.txt`
    /// and `docs/readme.txt`.
    ///
    /// Any directory left behind by an earlier run that panicked before `clean_static`
    /// is removed first. Every test therefore starts from known contents, instead of
    /// trusting a stale or partially-written tree.
    fn create_static(path: &str) {
        clean_static(path);
        std::fs::create_dir(path).unwrap();
        std::fs::write(format!("./{path}/index.html"), CONTENT).unwrap();
        std::fs::write(format!("./{path}/hello.txt"), "hello").unwrap();
        std::fs::create_dir(format!("./{path}/docs")).unwrap();
        std::fs::write(format!("./{path}/docs/readme.txt"), "doc").unwrap();
    }

    /// Removes the fixture tree at `path` if it exists.
    fn clean_static(path: &str) {
        if std::path::Path::new(path).is_dir() {
            std::fs::remove_dir_all(path).unwrap();
        }
    }

    /// A request for an existing file returns its contents.
    #[tokio::test]
    async fn test_static() {
        let path = "test_static";
        create_static(path);
        let handler = HandlerWrapperStatic::new(path, StaticOptions::default());
        let mut req = Request::default();
        req.set_path_params(
            "path".to_owned(),
            PathParam::path_owned("index.html".to_string()),
        );
        let mut res = handler.call(req).await.unwrap();
        clean_static(path);
        assert_eq!(res.status, StatusCode::OK);
        assert_eq!(
            res.body.frame().await.unwrap().unwrap().data_ref().unwrap(),
            &Bytes::from(CONTENT)
        );
    }

    /// An empty path serves the root `index.html`.
    #[tokio::test]
    async fn test_static_default() {
        let path = "test_static_default";
        create_static(path);
        let handler = HandlerWrapperStatic::new(path, StaticOptions::default());
        let mut req = Request::default();
        req.set_path_params("path".to_owned(), PathParam::path_owned(String::new()));
        let mut res = handler.call(req).await.unwrap();
        clean_static(path);
        assert_eq!(res.status, StatusCode::OK);
        assert_eq!(
            res.body.frame().await.unwrap().unwrap().data_ref().unwrap(),
            &Bytes::from(CONTENT)
        );
    }

    /// A missing file yields a 404 business error.
    #[tokio::test]
    async fn test_static_not_found() {
        let path = "test_static_not_found";
        create_static(path);
        let handler = HandlerWrapperStatic::new(path, StaticOptions::default());
        let mut req = Request::default();
        req.set_path_params(
            "path".to_owned(),
            PathParam::path_owned("not_found.html".to_string()),
        );
        let res = handler.call(req).await.unwrap_err();
        clean_static(path);
        if let SilentError::BusinessError { code, .. } = res {
            assert_eq!(code, StatusCode::NOT_FOUND);
        } else {
            panic!();
        }
    }

    /// The root listing shows entries and no parent link.
    #[tokio::test]
    async fn test_directory_listing() {
        let path = "test_static_listing";
        create_static(path);
        let options = StaticOptions::default().with_directory_listing();
        let handler = HandlerWrapperStatic::new(path, options);
        let mut req = Request::default();
        req.set_path_params("path".to_owned(), PathParam::path_owned(String::new()));
        let mut res = handler.call(req).await.unwrap();
        let body = res
            .body
            .frame()
            .await
            .unwrap()
            .unwrap()
            .data_ref()
            .unwrap()
            .clone();
        let body_str = String::from_utf8(body.to_vec()).unwrap();
        clean_static(path);
        assert!(body_str.contains("hello.txt"));
        assert!(body_str.contains("./"));
        assert!(!body_str.contains(">../<"));
    }

    /// `Accept-Encoding: gzip` with compression enabled yields `Content-Encoding: gzip`.
    #[tokio::test]
    async fn test_compression_negotiation() {
        let path = "test_static_compress";
        create_static(path);
        let options = StaticOptions::default().with_compression();
        let handler = HandlerWrapperStatic::new(path, options);
        let mut req = Request::default();
        req.headers_mut()
            .insert(ACCEPT_ENCODING, "gzip".parse().unwrap());
        req.set_path_params(
            "path".to_owned(),
            PathParam::path_owned("hello.txt".to_string()),
        );
        let res = handler.call(req).await.unwrap();
        clean_static(path);
        assert_eq!(
            res.headers()
                .get(CONTENT_ENCODING)
                .unwrap()
                .to_str()
                .unwrap(),
            "gzip"
        );
    }

    /// A subdirectory listing includes a link back to the parent.
    #[tokio::test]
    async fn test_directory_listing_subdir_has_parent_link() {
        let path = "test_static_listing_subdir";
        create_static(path);
        let options = StaticOptions::default().with_directory_listing();
        let handler = HandlerWrapperStatic::new(path, options);
        let mut req = Request::default();
        req.set_path_params(
            "path".to_owned(),
            PathParam::path_owned("docs/".to_string()),
        );
        let mut res = handler.call(req).await.unwrap();
        let body = res
            .body
            .frame()
            .await
            .unwrap()
            .unwrap()
            .data_ref()
            .unwrap()
            .clone();
        let body_str = String::from_utf8(body.to_vec()).unwrap();
        clean_static(path);
        assert!(body_str.contains(">../<"));
    }

    /// Text files are served with an explicit UTF-8 charset.
    #[tokio::test]
    async fn test_text_content_type_uses_utf8() {
        let path = "test_static_text_utf8";
        create_static(path);
        let handler = HandlerWrapperStatic::new(path, StaticOptions::default());
        let mut req = Request::default();
        req.set_path_params(
            "path".to_owned(),
            PathParam::path_owned("hello.txt".to_string()),
        );
        let res = handler.call(req).await.unwrap();
        clean_static(path);
        let header = res.headers().get(CONTENT_TYPE).unwrap().to_str().unwrap();
        assert!(header.contains("charset=utf-8"));
    }
}