pub mod runtime;
mod buf;
mod chunk;
mod date;
mod error;
pub use self::{chunk::ChunkReader, error::ServeError};
use std::{
io::SeekFrom,
path::{Component, Path, PathBuf},
};
use http::{
header::{HeaderValue, ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, LAST_MODIFIED, RANGE},
Method, Request, Response, StatusCode,
};
use mime_guess::mime;
use self::{
buf::buf_write_header,
runtime::{AsyncFs, ChunkRead, Meta},
};
/// Static file service that maps request URI paths onto files beneath a base directory.
///
/// With the `tokio` feature enabled, the file system backend `FS` defaults to
/// [`runtime::TokioFs`].
#[cfg(feature = "tokio")]
#[derive(Clone)]
pub struct ServeDir<FS: AsyncFs = runtime::TokioFs> {
/// Chunk size handed to `ChunkReader::reader` when streaming a response body.
chunk_size: usize,
/// Directory that validated request paths are joined onto.
base_path: PathBuf,
/// Backend used to open the files being served.
async_fs: FS,
}
/// Static file service that maps request URI paths onto files beneath a base directory.
///
/// Without the `tokio` feature there is no default backend, so `FS` must be named
/// explicitly (e.g. via [`ServeDir::with_fs`]).
#[cfg(not(feature = "tokio"))]
#[derive(Clone)]
pub struct ServeDir<FS: AsyncFs> {
/// Chunk size handed to `ChunkReader::reader` when streaming a response body.
chunk_size: usize,
/// Directory that validated request paths are joined onto.
base_path: PathBuf,
/// Backend used to open the files being served.
async_fs: FS,
}
#[cfg(feature = "tokio")]
impl ServeDir<runtime::TokioFs> {
    /// Construct a static file service rooted at `path`, backed by [`runtime::TokioFs`].
    ///
    /// Gated on the `tokio` feature, which is the same feature that makes `TokioFs` the
    /// default `FS` parameter of [`ServeDir`]. Gating on `default` instead would hide this
    /// constructor from users who enable `tokio` with `default-features = false`.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        Self::with_fs(path, runtime::TokioFs)
    }
}
#[cfg(feature = "tokio-uring")]
impl ServeDir<runtime::TokioUringFs> {
    /// Construct a static file service rooted at `path` whose file I/O goes through
    /// [`runtime::TokioUringFs`].
    pub fn new_tokio_uring(path: impl Into<PathBuf>) -> Self {
        let base: PathBuf = path.into();
        ServeDir::with_fs(base, runtime::TokioUringFs)
    }
}
impl<FS: AsyncFs> ServeDir<FS> {
    /// Create a service that resolves request paths beneath `path` and performs file I/O
    /// through the supplied `async_fs` backend.
    ///
    /// The body chunk size starts at 4096; change it with [`ServeDir::chunk_size`].
    pub fn with_fs(path: impl Into<PathBuf>, async_fs: FS) -> Self {
        let base_path = path.into();
        Self {
            base_path,
            async_fs,
            chunk_size: 4096,
        }
    }

    /// Set the chunk size handed to [`ChunkReader::reader`] when streaming a response body.
    ///
    /// Returns `&mut Self` so configuration calls can be chained.
    pub fn chunk_size(&mut self, size: usize) -> &mut Self {
        self.chunk_size = size;
        self
    }

    /// Resolve `req` against the base directory and build a streaming file response.
    ///
    /// Only `GET` and `HEAD` are served. A `HEAD` response carries the same headers as the
    /// matching `GET` but an empty body.
    ///
    /// The response always sets:
    /// - `Content-Type`, guessed from the path via `mime_guess` and falling back to
    ///   `application/octet-stream`;
    /// - `Content-Length`;
    /// - `Accept-Ranges: bytes`.
    ///
    /// It also sets `Last-Modified` when the modification-date check yields a date.
    ///
    /// When the `Range` header parses, the response is `206 Partial Content` for the last
    /// range produced by validation, with a matching `Content-Range`. A `Range` header that
    /// does not parse is ignored and the whole file is served.
    ///
    /// # Errors
    ///
    /// - [`ServeError::MethodNotAllowed`] for any method other than `GET`/`HEAD`.
    /// - [`ServeError::InvalidPath`] when the URI path is rejected or names a directory.
    /// - [`ServeError::RangeNotSatisfied`] when a parsed `Range` does not fit the file length.
    /// - Any error from opening or seeking the file, or from the modification-date check.
    pub async fn serve<Ext>(
        &self,
        req: &Request<Ext>,
    ) -> Result<Response<ChunkReader<FS::File>>, ServeError> {
        let head_only = match *req.method() {
            Method::GET => false,
            Method::HEAD => true,
            _ => return Err(ServeError::MethodNotAllowed),
        };

        let path = self.path_check(req.uri().path())?;
        if path.is_dir() {
            return Err(ServeError::InvalidPath);
        }

        let content_type = mime_guess::from_path(&path)
            .first_raw()
            .unwrap_or_else(|| mime::APPLICATION_OCTET_STREAM.as_ref());

        let mut file = self.async_fs.open(path).await?;
        let last_modified = date::mod_date_check(req, &mut file)?;

        let mut res = Response::new(());
        let size = file.len();

        // `None` when there is no `Range` header or it does not parse (it is then ignored).
        // Otherwise this holds the outcome of validating the parsed ranges against `size`.
        let validated = req
            .headers()
            .get(RANGE)
            .and_then(|value| value.to_str().ok())
            .and_then(|text| http_range_header::parse_range_header(text).ok())
            .map(|parsed| parsed.validate(size));

        let body_len = match validated {
            None => size,
            Some(outcome) => {
                let mut ranges = outcome.map_err(|_| ServeError::RangeNotSatisfied(size))?;
                // Only a single range is served: the last one that validation returned.
                let (start, end) = ranges
                    .pop()
                    .expect("http_range_header produced empty range")
                    .into_inner();
                file.seek(SeekFrom::Start(start)).await?;
                *res.status_mut() = StatusCode::PARTIAL_CONTENT;
                let content_range = buf_write_header!(0, "bytes {start}-{end}/{size}");
                res.headers_mut().insert(CONTENT_RANGE, content_range);
                end - start + 1
            }
        };

        {
            let headers = res.headers_mut();
            headers.insert(CONTENT_TYPE, HeaderValue::from_static(content_type));
            headers.insert(CONTENT_LENGTH, HeaderValue::from(body_len));
            headers.insert(ACCEPT_RANGES, HeaderValue::from_static("bytes"));
            if let Some(modified) = last_modified {
                headers.insert(LAST_MODIFIED, date::date_to_header(modified));
            }
        }

        let body = if head_only {
            ChunkReader::empty()
        } else {
            ChunkReader::reader(file, body_len, self.chunk_size)
        };

        Ok(res.map(|_| body))
    }
}
impl<FS: AsyncFs> ServeDir<FS> {
    /// Translate a request URI path into a file system path beneath `base_path`.
    ///
    /// Leading `/` characters are stripped and the remainder is percent-decoded as UTF-8.
    /// Every decoded component must be a plain name, except that `.` components are skipped.
    /// Root, prefix and `..` components are rejected, as is any name that re-parses on its
    /// own into a non-plain component.
    ///
    /// # Errors
    ///
    /// [`ServeError::InvalidPath`] when decoding fails or any component is rejected.
    fn path_check(&self, path: &str) -> Result<PathBuf, ServeError> {
        let encoded = path.trim_start_matches('/');
        let decoded = percent_encoding::percent_decode(encoded.as_bytes())
            .decode_utf8()
            .map_err(|_| ServeError::InvalidPath)?;

        Path::new(&*decoded)
            .components()
            .try_fold(self.base_path.clone(), |mut resolved, component| match component {
                Component::CurDir => Ok(resolved),
                Component::Normal(name) => {
                    // Re-parse the name by itself before pushing. `PathBuf::push` replaces the
                    // whole path when given an absolute/prefixed path, so presumably this guards
                    // against names that re-parse as a prefix or root (platform-dependent —
                    // confirm).
                    let plain = Path::new(name)
                        .components()
                        .all(|part| matches!(part, Component::Normal(_)));
                    if plain {
                        resolved.push(name);
                        Ok(resolved)
                    } else {
                        Err(ServeError::InvalidPath)
                    }
                }
                Component::Prefix(_) | Component::RootDir | Component::ParentDir => {
                    Err(ServeError::InvalidPath)
                }
            })
    }
}
#[cfg(test)]
mod test {
    use core::future::poll_fn;

    use futures_core::stream::Stream;

    use super::*;

    /// Compile-time check that `F` is `Send`; the referenced value is otherwise unused.
    fn assert_send<F: Send>(_: &F) {}

    #[tokio::test]
    async fn tokio_fs_assert_send() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/test.txt").body(()).unwrap();
        let fut = dir.serve(&req);
        assert_send(&fut);
        let res = fut.await.unwrap();
        assert_send(&res);
    }

    #[tokio::test]
    async fn method() {
        let dir = ServeDir::new("sample");
        let req = Request::builder()
            .method(Method::POST)
            .uri("/test.txt")
            .body(())
            .unwrap();
        let e = dir.serve(&req).await.err().unwrap();
        assert!(matches!(e, ServeError::MethodNotAllowed));
    }

    #[tokio::test]
    async fn head_method_body_check() {
        let dir = ServeDir::new("sample");
        let req = Request::builder()
            .method(Method::HEAD)
            .uri("/test.txt")
            .body(())
            .unwrap();
        let res = dir.serve(&req).await.unwrap();
        assert_eq!(
            res.headers().get(CONTENT_LENGTH).unwrap(),
            HeaderValue::from("hello, world!".len())
        );
        let mut stream = Box::pin(res.into_body());
        // NOTE(review): this pins the current hint of `ChunkReader::empty()`, whose lower
        // bound exceeds its upper bound. `Stream::size_hint` documents the pair as
        // (lower, upper) bounds, so confirm this value is intentional.
        assert_eq!(stream.size_hint(), (usize::MAX, Some(0)));
        let body_chunk = poll_fn(|cx| stream.as_mut().poll_next(cx)).await;
        assert!(body_chunk.is_none())
    }

    #[tokio::test]
    async fn invalid_path() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/../test.txt").body(()).unwrap();
        assert!(matches!(dir.serve(&req).await.err(), Some(ServeError::InvalidPath)));
    }

    /// Traversal and absolute paths must still be rejected once percent-decoded.
    #[tokio::test]
    async fn invalid_path_percent_encoded() {
        let dir = ServeDir::new("sample");
        for uri in ["/%2e%2e/test.txt", "/%2Fetc/passwd"] {
            let req = Request::builder().uri(uri).body(()).unwrap();
            assert!(
                matches!(dir.serve(&req).await.err(), Some(ServeError::InvalidPath)),
                "{uri} was not rejected"
            );
        }
    }

    /// A request resolving to a directory (here the base directory itself) is rejected.
    #[tokio::test]
    async fn directory_path() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/").body(()).unwrap();
        assert!(matches!(dir.serve(&req).await.err(), Some(ServeError::InvalidPath)));
    }

    #[tokio::test]
    async fn response_headers() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/test.txt").body(()).unwrap();
        let res = dir.serve(&req).await.unwrap();
        assert_eq!(
            res.headers().get(CONTENT_TYPE).unwrap(),
            HeaderValue::from_static("text/plain")
        );
        assert_eq!(
            res.headers().get(ACCEPT_RANGES).unwrap(),
            HeaderValue::from_static("bytes")
        );
        assert_eq!(
            res.headers().get(CONTENT_LENGTH).unwrap(),
            HeaderValue::from("hello, world!".len())
        );
    }

    #[tokio::test]
    async fn body_size_hint() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/test.txt").body(()).unwrap();
        let res = dir.serve(&req).await.unwrap();
        let (lower, Some(upper)) = res.body().size_hint() else {
            panic!("ChunkReadStream does not have a size")
        };
        assert_eq!(lower, upper);
        assert_eq!(lower, "hello, world!".len());
    }

    async fn _basic<FS: AsyncFs>(dir: ServeDir<FS>) {
        let req = Request::builder().uri("/test.txt").body(()).unwrap();
        let mut stream = Box::pin(dir.serve(&req).await.unwrap().into_body());
        let (low, high) = stream.size_hint();
        assert_eq!(low, high.unwrap());
        assert_eq!(low, "hello, world!".len());
        let mut res = String::new();
        while let Some(item) = poll_fn(|cx| stream.as_mut().poll_next(cx)).await {
            // Fail loudly on a read error instead of silently ending the loop, which would
            // only surface later as a confusing truncated-body mismatch.
            let Ok(bytes) = item else {
                panic!("body stream yielded an error")
            };
            res.push_str(std::str::from_utf8(bytes.as_ref()).unwrap());
        }
        assert_eq!(res, "hello, world!");
    }

    #[tokio::test]
    async fn basic() {
        _basic(ServeDir::new("sample")).await;
    }

    #[cfg(all(target_os = "linux", feature = "tokio-uring"))]
    #[test]
    fn basic_tokio_uring() {
        tokio_uring::start(_basic(ServeDir::new_tokio_uring("sample")));
    }

    async fn test_range<FS: AsyncFs>(dir: ServeDir<FS>) {
        let req = Request::builder()
            .uri("/test.txt")
            .header("range", "bytes=2-12")
            .body(())
            .unwrap();
        let res = dir.serve(&req).await.unwrap();
        assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT);
        assert_eq!(
            res.headers().get(CONTENT_TYPE).unwrap(),
            HeaderValue::from_static("text/plain")
        );
        assert_eq!(
            res.headers().get(CONTENT_RANGE).unwrap(),
            HeaderValue::from_static("bytes 2-12/13")
        );
        assert_eq!(
            res.headers().get(CONTENT_LENGTH).unwrap(),
            HeaderValue::from("llo, world!".len())
        );
        let mut stream = Box::pin(res.into_body());
        let mut res = String::new();
        while let Some(item) = poll_fn(|cx| stream.as_mut().poll_next(cx)).await {
            let Ok(bytes) = item else {
                panic!("body stream yielded an error")
            };
            res.push_str(std::str::from_utf8(bytes.as_ref()).unwrap());
        }
        assert_eq!("llo, world!", res);
    }

    #[tokio::test]
    async fn ranged() {
        test_range(ServeDir::new("sample")).await;
    }

    /// A syntactically valid range starting past the end of the 13-byte sample file is
    /// reported as unsatisfiable, carrying the full file length.
    #[tokio::test]
    async fn range_not_satisfiable() {
        let dir = ServeDir::new("sample");
        let req = Request::builder()
            .uri("/test.txt")
            .header("range", "bytes=20-30")
            .body(())
            .unwrap();
        assert!(matches!(
            dir.serve(&req).await.err(),
            Some(ServeError::RangeNotSatisfied(13))
        ));
    }

    #[cfg(all(target_os = "linux", feature = "tokio-uring"))]
    #[test]
    fn ranged_tokio_uring() {
        tokio_uring::start(test_range(ServeDir::new_tokio_uring("sample")))
    }
}