use super::AsyncReadBody;
use bytes::Bytes;
use futures_util::ready;
use http::{header, HeaderValue, Response};
use http_body::{combinators::BoxBody, Body};
use mime::Mime;
use std::{
future::Future,
io,
path::{Path, PathBuf},
pin::Pin,
task::{Context, Poll},
};
use tokio::fs::File;
use tower_service::Service;
/// Service that answers every request with the contents of a single file.
///
/// The request itself is never inspected; each call opens the file at `path`
/// and, on success, sets the response's `Content-Type` header to `mime`.
#[derive(Clone, Debug)]
pub struct ServeFile {
    // Location of the file that is opened on every call.
    path: PathBuf,
    // Pre-computed `Content-Type` header value attached to successful responses.
    mime: HeaderValue,
}
impl ServeFile {
    /// Create a service that serves the file at `path`.
    ///
    /// The `Content-Type` is the first guess `mime_guess` produces for `path`.
    /// When no guess is available it falls back to `application/octet-stream`.
    pub fn new<P: AsRef<Path>>(path: P) -> Self {
        let mime = mime_guess::from_path(&path)
            .first_raw()
            .map(HeaderValue::from_static)
            // A static literal is infallible, so no parse + `unwrap` is needed here.
            .unwrap_or_else(|| HeaderValue::from_static("application/octet-stream"));
        let path = path.as_ref().to_owned();
        Self { path, mime }
    }

    /// Create a service that serves the file at `path` with an explicit `Content-Type`.
    ///
    /// # Panics
    ///
    /// Panics if `mime` cannot be converted into an HTTP header value.
    pub fn new_with_mime<P: AsRef<Path>>(path: P, mime: &Mime) -> Self {
        let mime = HeaderValue::from_str(mime.as_ref()).expect("mime isn't a valid header value");
        let path = path.as_ref().to_owned();
        Self { path, mime }
    }
}
impl<R> Service<R> for ServeFile {
    type Response = Response<ResponseBody>;
    type Error = io::Error;
    type Future = ResponseFuture;

    /// The service has nothing to prepare, so it is always ready.
    #[inline]
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Begin opening the configured file; the incoming request is ignored.
    fn call(&mut self, _req: R) -> Self::Future {
        let mime = Some(self.mime.clone());
        let target = self.path.clone();
        ResponseFuture {
            mime,
            open_file_future: Box::pin(File::open(target)),
        }
    }
}
/// Response future of [`ServeFile`].
///
/// Resolves once the file has been opened, yielding either a response that
/// streams the file or the outcome of mapping the I/O error.
pub struct ResponseFuture {
    // Boxed `File::open` future created by `ServeFile::call`.
    open_file_future: Pin<Box<dyn Future<Output = io::Result<File>> + Send + Sync + 'static>>,
    // `Content-Type` value for the response; taken (left as `None`) when the
    // successful response is built.
    mime: Option<HeaderValue>,
}

impl std::fmt::Debug for ResponseFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The boxed `dyn Future` has no `Debug` impl, so only the MIME value is shown.
        f.debug_struct("ResponseFuture")
            .field("mime", &self.mime)
            .finish()
    }
}
impl Future for ResponseFuture {
    type Output = io::Result<Response<ResponseBody>>;

    /// Drive the file-open future to completion.
    ///
    /// On success, the opened file becomes the response body and the stored
    /// `Content-Type` header is attached. If opening fails, the error is passed
    /// to `super::response_from_io_error` (defined outside this file). This
    /// file's tests expect a missing file to yield a `404 Not Found` response
    /// without a `Content-Type` header.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already resolved successfully.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let file = match ready!(self.open_file_future.as_mut().poll(cx)) {
            Ok(file) => file,
            Err(err) => {
                return Poll::Ready(
                    super::response_from_io_error(err).map(|res| res.map(ResponseBody)),
                );
            }
        };

        // `mime` is only taken here, on the single successful completion.
        let mime = self
            .mime
            .take()
            .expect("`ResponseFuture` polled after completion");

        let body = ResponseBody(AsyncReadBody::new(file).boxed());
        let mut res = Response::new(body);
        res.headers_mut().insert(header::CONTENT_TYPE, mime);
        Poll::Ready(Ok(res))
    }
}
// Response body type returned by `ServeFile`. Uses elsewhere in this file
// (`ResponseBody(body)`, `res.map(ResponseBody)`) show it is a tuple-struct
// wrapper around `BoxBody<Bytes, io::Error>`. Presumably the macro exists to keep
// the concrete boxed type out of the public API — confirm against its definition.
opaque_body! {
    pub type ResponseBody = BoxBody<Bytes, io::Error>;
}
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use http::{Request, StatusCode};
    use http_body::Body as _;
    use hyper::Body;
    use tower::ServiceExt;

    /// Send a single empty request through a fresh `ServeFile` for `path`.
    async fn serve(path: &str) -> Response<ResponseBody> {
        ServeFile::new(path)
            .oneshot(Request::new(Body::empty()))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn basic() {
        let response = serve("../README.md").await;
        assert_eq!(response.headers()["content-type"], "text/markdown");

        let chunk = response.into_body().data().await.unwrap().unwrap();
        let text = String::from_utf8(chunk.to_vec()).unwrap();
        assert!(text.starts_with("# Tower HTTP"));
    }

    #[tokio::test]
    async fn returns_404_if_file_doesnt_exist() {
        let response = serve("../this-doesnt-exist.md").await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        assert!(response.headers().get(header::CONTENT_TYPE).is_none());
    }
}