use bytes::Bytes;
use futures_util::Stream;
use http::{HeaderMap, Response, StatusCode};
use http_body::{combinators::BoxBody, Body, Empty};
use httpdate::HttpDate;
use pin_project_lite::pin_project;
use std::fs::Metadata;
use std::{ffi::OsStr, path::PathBuf};
use std::{
io,
pin::Pin,
task::{Context, Poll},
time::SystemTime,
};
use tokio::fs::File;
use tokio::io::{AsyncRead, AsyncReadExt, Take};
use tokio_util::io::ReaderStream;
mod serve_dir;
mod serve_file;
// 65536 bytes (64 KiB).
// NOTE(review): presumably the default read-buffer capacity handed to
// `AsyncReadBody::with_capacity` by the submodules; the use sites are not
// visible in this file, so confirm before relying on that.
const DEFAULT_CAPACITY: usize = 65536;
use crate::content_encoding::{Encoding, QValue, SupportedEncodings};
pub use self::{
serve_dir::{
ResponseBody as ServeFileSystemResponseBody,
ResponseFuture as ServeFileSystemResponseFuture,
ServeDir,
},
serve_file::ServeFile,
};
// One flag per content encoding (gzip, deflate, brotli). All flags default to
// `false` through the derived `Default`.
// NOTE(review): judging by the name, a set flag means a precompressed copy of
// a file may be looked up for that encoding; the code that sets and reads the
// flags is outside this file.
#[derive(Clone, Copy, Debug, Default)]
struct PrecompressedVariants {
gzip: bool,
deflate: bool,
br: bool,
}
// Exposes the stored flags through the crate's `SupportedEncodings` trait.
// Each method returns the field of the same name unchanged.
impl SupportedEncodings for PrecompressedVariants {
fn gzip(&self) -> bool {
self.gzip
}
fn deflate(&self) -> bool {
self.deflate
}
fn br(&self) -> bool {
self.br
}
}
/// Picks the most preferred encoding from `negotiated_encoding` and rewrites
/// `path` in place so that it points at the matching precompressed file.
///
/// The encoding's file suffix is appended verbatim to the final path
/// component, so it is expected to carry its own leading dot
/// (`foo.txt` + `.gz` -> `foo.txt.gz`). Appending to the whole file name,
/// rather than rebuilding the extension with `set_extension`, also gives the
/// right name for files that have no extension (`foo` -> `foo.gz`);
/// `set_extension` inserts a dot of its own and would produce `foo..gz`.
///
/// `path` is left untouched when no encoding is preferred, when the preferred
/// encoding has no file suffix, or when the path has no final file name.
///
/// Returns the preferred encoding, if any.
fn preferred_encoding(
    path: &mut PathBuf,
    negotiated_encoding: &[(Encoding, QValue)],
) -> Option<Encoding> {
    let preferred_encoding = Encoding::preferred_encoding(negotiated_encoding);

    let suffix = preferred_encoding.and_then(|encoding| encoding.to_file_extension());
    if let Some(suffix) = suffix {
        // Build the new name as an owned value first so the shared borrow of
        // `path` has ended before the path is mutated.
        let new_file_name = path.file_name().map(|file_name| {
            let mut os_string = file_name.to_os_string();
            os_string.push(suffix);
            os_string
        });
        if let Some(new_file_name) = new_file_name {
            path.set_file_name(new_file_name);
        }
    }

    preferred_encoding
}
async fn open_file_with_fallback(
mut path: PathBuf,
mut negotiated_encoding: Vec<(Encoding, QValue)>,
) -> io::Result<(File, Option<Encoding>)> {
let (file, encoding) = loop {
let encoding = preferred_encoding(&mut path, &negotiated_encoding);
match (File::open(&path).await, encoding) {
(Ok(file), maybe_encoding) => break (file, maybe_encoding),
(Err(err), Some(encoding)) if err.kind() == io::ErrorKind::NotFound => {
path.set_extension(OsStr::new(""));
negotiated_encoding
.retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding);
continue;
}
(Err(err), _) => return Err(err),
};
};
Ok((file, encoding))
}
async fn file_metadata_with_fallback(
mut path: PathBuf,
mut negotiated_encoding: Vec<(Encoding, QValue)>,
) -> io::Result<(Metadata, Option<Encoding>)> {
let (file, encoding) = loop {
let encoding = preferred_encoding(&mut path, &negotiated_encoding);
match (tokio::fs::metadata(&path).await, encoding) {
(Ok(file), maybe_encoding) => break (file, maybe_encoding),
(Err(err), Some(encoding)) if err.kind() == io::ErrorKind::NotFound => {
path.set_extension(OsStr::new(""));
negotiated_encoding
.retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding);
continue;
}
(Err(err), _) => return Err(err),
};
};
Ok((file, encoding))
}
pin_project! {
/// Adapter that wraps a reader in a `ReaderStream` and exposes it as an HTTP
/// `Body` whose data chunks are `Bytes` and whose error type is `io::Error`.
#[derive(Debug)]
pub struct AsyncReadBody<T> {
// Structurally pinned so `poll_data` can poll the stream through the
// projection.
#[pin]
reader: ReaderStream<T>,
}
}
impl<T> AsyncReadBody<T>
where
    T: AsyncRead,
{
    /// Wraps `read` in a body whose underlying `ReaderStream` is created with
    /// the given buffer `capacity`.
    fn with_capacity(read: T, capacity: usize) -> Self {
        let reader = ReaderStream::with_capacity(read, capacity);
        Self { reader }
    }

    /// Like [`AsyncReadBody::with_capacity`], but the reader is first wrapped
    /// in `Take` so that at most `max_read_bytes` bytes are read from it.
    fn with_capacity_limited(
        read: T,
        capacity: usize,
        max_read_bytes: u64,
    ) -> AsyncReadBody<Take<T>> {
        let limited = read.take(max_read_bytes);
        let reader = ReaderStream::with_capacity(limited, capacity);
        AsyncReadBody { reader }
    }
}
impl<T> Body for AsyncReadBody<T>
where
T: AsyncRead,
{
type Data = Bytes;
type Error = io::Error;
fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
self.project().reader.poll_next(cx)
}
fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
}
/// Translates an I/O error into an HTTP response where one applies.
///
/// `NotFound` and `PermissionDenied` both become an empty-bodied
/// `404 Not Found` response. Every other error kind is handed back unchanged
/// as `Err`.
fn response_from_io_error(
    err: io::Error,
) -> Result<Response<BoxBody<Bytes, io::Error>>, io::Error> {
    let report_as_missing = matches!(
        err.kind(),
        io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied
    );
    if !report_as_missing {
        return Err(err);
    }

    // The empty body's error type is uninhabited, so the empty match is enough
    // to convert it to the boxed body's `io::Error`.
    let body = Empty::new().map_err(|err| match err {}).boxed();
    let response = Response::builder()
        .status(StatusCode::NOT_FOUND)
        .body(body)
        .unwrap();
    Ok(response)
}
struct LastModified(HttpDate);
impl From<SystemTime> for LastModified {
fn from(time: SystemTime) -> Self {
LastModified(time.into())
}
}
// Date parsed from a header value by `IfUnmodifiedSince::from_header_value`;
// compared against `LastModified` in `precondition_passes`.
struct IfUnmodifiedSince(HttpDate);
// Date parsed from a header value by `IfModifiedSince::from_header_value`;
// compared against `LastModified` in `is_modified`.
struct IfModifiedSince(HttpDate);
impl IfModifiedSince {
fn is_modified(&self, last_modified: &LastModified) -> bool {
self.0 < last_modified.0
}
fn from_header_value(value: &http::header::HeaderValue) -> Option<IfModifiedSince> {
std::str::from_utf8(value.as_bytes())
.ok()
.and_then(|value| httpdate::parse_http_date(&value).ok())
.map(|time| IfModifiedSince(time.into()))
}
}
impl IfUnmodifiedSince {
fn precondition_passes(&self, last_modified: &LastModified) -> bool {
self.0 >= last_modified.0
}
fn from_header_value(value: &http::header::HeaderValue) -> Option<IfUnmodifiedSince> {
std::str::from_utf8(value.as_bytes())
.ok()
.and_then(|value| httpdate::parse_http_date(&value).ok())
.map(|time| IfUnmodifiedSince(time.into()))
}
}
/// Evaluates the two date-based conditional-request values against the
/// resource's modification time.
///
/// * With `if_unmodified_since` present, returns `412 Precondition Failed`
///   unless `modified` is known and passes the precondition.
/// * Otherwise, with `if_modified_since` present, returns `304 Not Modified`
///   when `modified` is known and is not newer than the supplied date.
/// * Returns `None` when neither check produces a status.
///
/// The unmodified-since check is evaluated first and wins if both apply.
fn check_modified_headers(
    modified: Option<&LastModified>,
    if_unmodified_since: Option<IfUnmodifiedSince>,
    if_modified_since: Option<IfModifiedSince>,
) -> Option<StatusCode> {
    if let Some(since) = if_unmodified_since {
        // An unknown modification time counts as a failed precondition.
        let passes = match modified {
            Some(time) => since.precondition_passes(time),
            None => false,
        };
        if !passes {
            return Some(StatusCode::PRECONDITION_FAILED);
        }
    }

    if let Some(since) = if_modified_since {
        // An unknown modification time counts as "changed".
        let changed = match modified {
            Some(time) => since.is_modified(time),
            None => true,
        };
        if !changed {
            return Some(StatusCode::NOT_MODIFIED);
        }
    }

    None
}