use std::collections::HashMap;
use std::future::Future;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::Bytes;
use digest::Digest;
use futures_core::ready;
use http::{Request, header};
use pin_project_lite::pin_project;
use thiserror::Error;
use tower::Service;
use crate::Body;
use crate::config::{StaticFilesConfig, StaticFilesPathRewriteMode};
use crate::error::error_impl::impl_into_cot_error;
use crate::project::MiddlewareContext;
use crate::response::{Response, ResponseExt};
/// Builds a `Vec` of `StaticFile`s from a comma-separated list of path
/// literals (a trailing comma is allowed).
///
/// Each path is resolved relative to the `static/` directory located in the
/// `CARGO_MANIFEST_DIR` of the crate that invokes the macro. The file contents
/// are embedded into the binary at compile time with `include_bytes!`, so a
/// missing file is a compile error rather than a runtime error. The literal
/// itself (without the `static/` prefix) becomes the path of the resulting
/// `StaticFile`.
#[macro_export]
macro_rules! static_files {
($($path:literal),* $(,)?) => {
::std::vec![$(
$crate::static_files::StaticFile::new(
$path.to_string(),
$crate::__private::Bytes::from_static(
// Embedded at compile time; the path is rooted at the invoking crate's manifest directory.
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/static/", $path))
),
)
),*]
};
}
/// Registry of the static files known to the application, together with the
/// settings used to build their public URLs and to serve them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct StaticFiles {
/// Prefix prepended to a file's path when building its URL, and stripped
/// from incoming request paths before looking a file up.
url_prefix: String,
/// Registered files, keyed by the file's path (without the URL prefix).
files: HashMap<String, StaticFileWithMeta>,
/// Selects how the URL of each file is generated (see `file_url`).
rewrite_mode: StaticFilesPathRewriteMode,
/// When set, served files get a `Cache-Control: max-age=<seconds>` header.
cache_timeout: Option<Duration>,
}
/// A registered static file paired with the URL computed for it at
/// registration time.
#[derive(Debug, Clone, PartialEq, Eq)]
struct StaticFileWithMeta {
/// URL of the file as produced by `StaticFiles::file_url` (URL prefix +
/// path, plus a `?v=` query parameter in the query-param rewrite mode).
url: String,
/// The file itself.
file: StaticFile,
}
impl StaticFiles {
    /// Number of leading SHA-256 digest bytes used for the `?v=` cache-busting
    /// value. The value is hex-encoded, so it is twice this many characters.
    const HASH_PREFIX_LEN: usize = 6;

    /// Creates an empty registry that uses the URL prefix, rewrite mode and
    /// cache timeout from `config`.
    #[must_use]
    pub(crate) fn new(config: &StaticFilesConfig) -> Self {
        Self {
            url_prefix: config.url.clone(),
            files: HashMap::new(),
            rewrite_mode: config.rewrite.clone(),
            cache_timeout: config.cache_timeout,
        }
    }

    /// Registers `file` under its path and precomputes its URL.
    ///
    /// A file registered earlier under the same path is replaced.
    pub(crate) fn add_file(&mut self, file: StaticFile) {
        let url = self.file_url(&file);
        // The key is cloned before `file` is moved into the stored entry.
        self.files
            .insert(file.path.clone(), StaticFileWithMeta { url, file });
    }

    /// Builds the URL for `file` according to the configured rewrite mode.
    ///
    /// * `None`: URL prefix followed by the file path.
    /// * `QueryParam`: as above, plus `?v=<hash>` where `<hash>` comes from
    ///   [`Self::file_hash`], so the URL changes whenever the content does.
    fn file_url(&self, file: &StaticFile) -> String {
        // `format!` only needs to borrow the path; no clone is required.
        match self.rewrite_mode {
            StaticFilesPathRewriteMode::None => {
                format!("{}{}", self.url_prefix, file.path)
            }
            StaticFilesPathRewriteMode::QueryParam => {
                format!(
                    "{}{}?v={}",
                    self.url_prefix,
                    file.path,
                    Self::file_hash(file)
                )
            }
        }
    }

    /// Returns the hex encoding of the first [`Self::HASH_PREFIX_LEN`] bytes of
    /// the SHA-256 digest of the file's content.
    #[must_use]
    fn file_hash(file: &StaticFile) -> String {
        let digest = sha2::Sha256::digest(&file.content);
        hex::encode(&digest.as_slice()[..Self::HASH_PREFIX_LEN])
    }

    /// Looks up a registered file by its path (without the URL prefix).
    #[must_use]
    fn get_file(&self, path: &str) -> Option<&StaticFile> {
        self.files.get(path).map(|entry| &entry.file)
    }

    /// Returns the precomputed URL of the file registered under `path`, or
    /// `None` if no such file is registered.
    #[must_use]
    pub(crate) fn path_for(&self, path: &str) -> Option<&str> {
        self.files.get(path).map(|entry| entry.url.as_str())
    }

    /// Writes every registered file below the directory `path`, creating
    /// intermediate directories as needed.
    ///
    /// # Errors
    ///
    /// Returns [`CollectStaticError`] if a directory cannot be created or a
    /// file cannot be written. Writing stops at the first failure.
    pub(crate) fn collect_into(&self, path: &Path) -> Result<(), CollectStaticError> {
        for (relative_path, entry) in &self.files {
            let target = path.join(relative_path);
            // `parent()` is `None` only for degenerate targets (an empty or
            // root-only path). In that case there is nothing to create and the
            // write below reports the problem as an I/O error instead of
            // panicking.
            if let Some(parent) = target.parent() {
                std::fs::create_dir_all(parent)?;
            }
            std::fs::write(&target, &entry.file.content)?;
        }
        Ok(())
    }
}
/// Error returned when writing the static files to disk fails.
///
/// Wraps the underlying `std::io::Error`; the `#[from]` attribute lets `?`
/// convert I/O errors into this type.
#[derive(Debug, Error)]
#[error("could not collect static files: {0}")]
pub(crate) struct CollectStaticError(#[from] std::io::Error);
// NOTE(review): macro is defined elsewhere in the crate; presumably it
// provides the conversion of this error into the crate-wide error type.
impl_into_cot_error!(CollectStaticError);
impl From<&MiddlewareContext> for StaticFiles {
fn from(context: &MiddlewareContext) -> Self {
let mut static_files = StaticFiles::new(&context.config().static_files);
for module in context.apps() {
for file in module.static_files() {
static_files.add_file(file);
}
}
static_files
}
}
/// A static file: its path, its contents and the MIME type used when it is
/// served over HTTP.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StaticFile {
/// Path of the file; used as the lookup key and appended to the URL prefix.
path: String,
/// Raw contents of the file.
content: Bytes,
/// MIME type sent in the `Content-Type` header of the response.
mime_type: mime_guess::Mime,
}
impl StaticFile {
    /// Creates a static file from a path and its contents.
    ///
    /// The MIME type is guessed from the path; when no guess is available it
    /// falls back to the octet-stream type.
    #[must_use]
    pub fn new<Path, Content>(path: Path, content: Content) -> Self
    where
        Path: Into<String>,
        Content: Into<Bytes>,
    {
        let owned_path: String = path.into();
        let bytes: Bytes = content.into();
        Self {
            // Evaluated before `owned_path` is moved into the struct below.
            mime_type: mime_guess::from_path(&owned_path).first_or_octet_stream(),
            path: owned_path,
            content: bytes,
        }
    }

    /// Builds an HTTP response carrying the file's contents, with the
    /// `Content-Type` header set to the file's MIME type.
    #[must_use]
    fn as_response(&self) -> Response {
        let content_type = self.mime_type.to_string();
        let body = Body::fixed(self.content.clone());
        Response::builder()
            .header(header::CONTENT_TYPE, content_type)
            .body(body)
            .expect("failed to build static file response")
    }
}
/// Middleware (a `tower::Layer`) that wraps a service in a
/// `StaticFilesService`, which answers requests for registered static files
/// and forwards everything else to the wrapped service.
#[derive(Debug, Clone)]
pub struct StaticFilesMiddleware {
/// Registry shared with every service created by this layer.
static_files: Arc<StaticFiles>,
}
impl StaticFilesMiddleware {
#[must_use]
pub fn from_context(context: &MiddlewareContext) -> Self {
Self {
static_files: Arc::new(StaticFiles::from(context)),
}
}
}
impl<S> tower::Layer<S> for StaticFilesMiddleware {
    type Service = StaticFilesService<S>;

    /// Wraps `inner` in a [`StaticFilesService`] that shares this
    /// middleware's file registry.
    fn layer(&self, inner: S) -> Self::Service {
        let shared_files = Arc::clone(&self.static_files);
        StaticFilesService::new(shared_files, inner)
    }
}
/// Service that serves registered static files itself and delegates all other
/// requests to the wrapped service.
#[derive(Clone, Debug)]
pub struct StaticFilesService<S> {
/// Registry of files this service can serve.
static_files: Arc<StaticFiles>,
/// Service that handles every request not matching a static file.
inner: S,
}
impl<S> StaticFilesService<S> {
#[must_use]
fn new(static_files: Arc<StaticFiles>, inner: S) -> Self {
Self {
static_files,
inner,
}
}
}
impl<ReqBody, S> Service<Request<ReqBody>> for StaticFilesService<S>
where
    S: Service<Request<ReqBody>, Response = Response>,
{
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;
    type Response = S::Response;

    /// Readiness is entirely determined by the wrapped service.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Serves the request from the static file registry when its path, minus
    /// the URL prefix, names a registered file; otherwise forwards it to the
    /// wrapped service.
    ///
    /// Only the path component of the URI is matched, so a query string (such
    /// as the `?v=` parameter added by URL rewriting) does not affect lookup.
    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        let static_response = req
            .uri()
            .path()
            .strip_prefix(self.static_files.url_prefix.as_str())
            .and_then(|relative| self.static_files.get_file(relative))
            .map(StaticFile::as_response);

        let Some(mut response) = static_response else {
            // Not a static file: hand the registry to downstream handlers via
            // the request extensions and delegate.
            req.extensions_mut().insert(Arc::clone(&self.static_files));
            return ResponseFuture::Inner {
                future: self.inner.call(req),
            };
        };

        if let Some(timeout) = self.static_files.cache_timeout {
            let directive = format!("max-age={}", timeout.as_secs());
            let value = header::HeaderValue::from_str(&directive)
                .expect("failed to create cache control header");
            response.headers_mut().insert(header::CACHE_CONTROL, value);
        }
        ResponseFuture::StaticFileResponse { response }
    }
}
// Future returned by the `Service` implementation of `StaticFilesService`.
//
// It is either a response that was already built from a static file, or the
// future of the wrapped service. Plain `//` comments are used here on purpose:
// the enum carries `#[expect(missing_docs)]`, and the definition is passed
// through the `pin_project!` macro.
pin_project! {
#[project = ResponseFutureProj]
#[expect(missing_docs)] pub enum ResponseFuture<F> {
// A ready-made response for a static file; handed out on first poll.
StaticFileResponse {
response: Response,
},
// The wrapped service's future; structurally pinned so it can be polled.
Inner {
#[pin]
future: F,
},
}
}
impl<F, E> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, E>>,
{
    type Output = F::Output;

    /// Resolves immediately with the prepared static file response, or
    /// forwards the poll to the wrapped service's future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.project() {
            ResponseFutureProj::Inner { future } => {
                // Pending is propagated by `ready!`; a finished result, success
                // or error, is passed through unchanged.
                let outcome = ready!(future.poll(cx));
                Poll::Ready(outcome)
            }
            ResponseFutureProj::StaticFileResponse { response } => {
                // Move the response out, leaving a default one in its place.
                let served = std::mem::take(response);
                Poll::Ready(Ok(served))
            }
        }
    }
}
// Unit tests for the static file registry, the `static_files!` macro and the
// static files middleware.
#[cfg(test)]
mod tests {
use std::fs;
use http::{Request, StatusCode};
use tower::{Layer, ServiceExt};
use super::*;
use crate::config::{ProjectConfig, StaticFilesConfig, StaticFilesPathRewriteMode};
use crate::project::RegisterAppsContext;
use crate::{App, AppBuilder, Bootstrapper, Project};
// A file added to the registry can be retrieved by its path.
// NOTE(review): this test does not appear to touch the database, so the miri
// ignore reason below may be a leftover — confirm before removing.
#[test]
#[cfg_attr(
miri,
ignore = "unsupported operation: can't call foreign function `sqlite3_open_v2`"
)]
fn static_files_add_and_get_file() {
let mut static_files = StaticFiles::new(&StaticFilesConfig::default());
static_files.add_file(StaticFile::new("test.txt", "This is a test file"));
let file = static_files.get_file("test.txt");
assert!(file.is_some());
assert_eq!(file.unwrap().content, Bytes::from("This is a test file"));
}
// `as_response` yields a 200 response with the file's MIME type and content.
#[cot::test]
async fn file_as_response() {
let file = StaticFile {
path: "test.txt".to_owned(),
content: Bytes::from("This is a test file"),
mime_type: mime::TEXT_PLAIN,
};
let response = file.as_response();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.headers()["content-type"], "text/plain");
assert_eq!(
response.into_body().into_bytes().await.unwrap(),
Bytes::from("This is a test file")
);
}
// Helper: a default-configured registry containing a single `test.txt`.
fn create_static_files() -> StaticFiles {
let mut static_files = StaticFiles::new(&StaticFilesConfig::default());
static_files.add_file(StaticFile::new("test.txt", "This is a test file"));
static_files
}
// A request for a registered file is answered by the middleware itself.
// The request URI relies on the default configuration's URL prefix being
// `/static/`.
#[cot::test]
async fn static_files_middleware() {
let static_files = Arc::new(create_static_files());
let middleware = StaticFilesMiddleware {
static_files: Arc::clone(&static_files),
};
let service = middleware.layer(tower::service_fn(|_req| async {
Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
}));
let request = Request::builder()
.uri("/static/test.txt")
.body(Body::empty())
.unwrap();
let response = service.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.headers()["content-type"], "text/plain");
assert_eq!(
response.into_body().into_bytes().await.unwrap(),
Bytes::from("This is a test file")
);
}
// With a custom prefix and query-param rewriting, the URL returned by
// `path_for` (including its `?v=` query) is still served by the middleware.
#[cot::test]
async fn static_files_middleware_with_config() {
let mut static_files = StaticFiles::new(
&StaticFilesConfig::builder()
.url("/assets/")
.rewrite(StaticFilesPathRewriteMode::QueryParam)
.cache_timeout(Duration::from_secs(300))
.build(),
);
static_files.add_file(StaticFile::new("test.txt", "This is a test file"));
let static_files = Arc::new(static_files);
let middleware = StaticFilesMiddleware {
static_files: Arc::clone(&static_files),
};
let service = middleware.layer(tower::service_fn(|_req| async {
Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
}));
let url = static_files.path_for("test.txt").unwrap();
assert!(url.starts_with("/assets/test.txt?v="));
let request = Request::builder().uri(url).body(Body::empty()).unwrap();
let response = service.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.headers()["content-type"], "text/plain");
assert_eq!(
response.into_body().into_bytes().await.unwrap(),
Bytes::from("This is a test file")
);
}
// A request under the static prefix that matches no registered file is
// forwarded to the inner service, whose body ("test") is returned.
#[cot::test]
async fn static_files_middleware_not_found() {
let static_files = Arc::new(create_static_files());
let middleware = StaticFilesMiddleware {
static_files: Arc::clone(&static_files),
};
let service = middleware.layer(tower::service_fn(|_req| async {
Ok::<_, std::convert::Infallible>(Response::new(Body::fixed("test")))
}));
let request = Request::builder()
.uri("/static/nonexistent.txt")
.body(Body::empty())
.unwrap();
let response = service.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response.into_body().into_bytes().await.unwrap(),
Bytes::from("test") );
}
// `from_context` collects the static files of every registered app: one app
// supplies a file through the `static_files!` macro, the other builds a
// `StaticFile` by hand. Bootstrapping opens a database, hence the miri ignore.
#[cot::test]
#[cfg_attr(
miri,
ignore = "unsupported operation: can't call foreign function `sqlite3_open_v2`"
)]
async fn static_files_middleware_from_context() {
struct App1;
impl App for App1 {
fn name(&self) -> &'static str {
"app1"
}
fn static_files(&self) -> Vec<StaticFile> {
static_files!("test/test.txt")
}
}
struct App2;
impl App for App2 {
fn name(&self) -> &'static str {
"app2"
}
fn static_files(&self) -> Vec<StaticFile> {
vec![StaticFile::new("app2/test.js", "test")]
}
}
struct TestProject;
impl Project for TestProject {
fn register_apps(&self, apps: &mut AppBuilder, _context: &RegisterAppsContext) {
apps.register(App1);
apps.register(App2);
}
}
let bootstrapper = Bootstrapper::new(TestProject)
.with_config(ProjectConfig::default())
.with_apps()
.with_database()
.await
.unwrap()
.with_cache()
.await
.unwrap();
let middleware = StaticFilesMiddleware::from_context(bootstrapper.context());
let static_files = middleware.static_files;
let file = static_files.get_file("test/test.txt").unwrap();
assert_eq!(file.mime_type, mime::TEXT_PLAIN);
assert_eq!(
file.content,
Bytes::from_static(include_bytes!("../static/test/test.txt"))
);
let file = static_files.get_file("app2/test.js").unwrap();
assert_eq!(file.content, Bytes::from("test"));
}
// `collect_into` writes top-level and nested files, creating directories.
#[test]
fn collect_into() {
let temp_dir = tempfile::tempdir().unwrap();
let temp_path = temp_dir.path().to_path_buf();
let mut static_files = StaticFiles::new(&StaticFilesConfig::default());
static_files.add_file(StaticFile::new("test.txt", "This is a test file"));
static_files.add_file(StaticFile::new(
"nested/test2.txt",
"This is another test file",
));
static_files.collect_into(&temp_path).unwrap();
let file_path = temp_path.join("test.txt");
let nested_file_path = temp_path.join("nested/test2.txt");
assert!(file_path.exists());
assert_eq!(
fs::read_to_string(file_path).unwrap(),
"This is a test file"
);
assert!(nested_file_path.exists());
assert_eq!(
fs::read_to_string(nested_file_path).unwrap(),
"This is another test file"
);
}
// Collecting an empty registry succeeds and leaves the target directory empty.
#[test]
fn collect_into_empty() {
let temp_dir = tempfile::tempdir().unwrap();
let temp_path = temp_dir.path().to_path_buf();
let static_files = StaticFiles::new(&StaticFilesConfig::default());
static_files.collect_into(&temp_path).unwrap();
assert!(fs::read_dir(&temp_path).unwrap().next().is_none());
}
// The macro produces one `StaticFile` per literal, with the literal as its
// path and the embedded file contents as its content.
#[test]
fn static_files_macro() {
let static_files = static_files!("test/test.txt");
assert_eq!(static_files.len(), 1);
assert_eq!(static_files[0].path, "test/test.txt");
assert_eq!(
static_files[0].content,
Bytes::from_static(include_bytes!("../static/test/test.txt"))
);
}
// The macro accepts a trailing comma.
#[test]
fn static_files_macro_trailing_comma() {
let static_files = static_files!("test/test.txt",);
assert_eq!(static_files.len(), 1);
}
// The MIME type is guessed from the file extension; a path without an
// extension falls back to `application/octet-stream`.
#[test]
fn static_file_mime_type_detection() {
let file = StaticFile::new("style.css", "body { color: red; }");
assert_eq!(file.mime_type, mime::TEXT_CSS);
let file = StaticFile::new("script.js", "console.log('test');");
assert_eq!(file.mime_type, mime::TEXT_JAVASCRIPT);
let file = StaticFile::new("image.png", "fake image data");
assert_eq!(file.mime_type, mime::IMAGE_PNG);
let file = StaticFile::new("unknown", "some content");
assert_eq!(file.mime_type, mime::APPLICATION_OCTET_STREAM);
}
// URL generation in both rewrite modes: plain prefix + path, and prefix +
// path + `?v=<hash>`.
#[test]
fn static_files_url_rewriting() {
let mut static_files = StaticFiles::new(&StaticFilesConfig {
url: "/static/".to_string(),
rewrite: StaticFilesPathRewriteMode::None,
cache_timeout: None,
});
let file = StaticFile::new("test.txt", "test content");
static_files.add_file(file);
let url = static_files.path_for("test.txt").unwrap();
assert_eq!(url, "/static/test.txt");
let mut static_files = StaticFiles::new(&StaticFilesConfig {
url: "/static/".to_string(),
rewrite: StaticFilesPathRewriteMode::QueryParam,
cache_timeout: None,
});
let file = StaticFile::new("test.txt", "test content");
static_files.add_file(file);
let url = static_files.path_for("test.txt").unwrap();
assert!(url.starts_with("/static/test.txt?v="));
// 12 hex characters: the hash is the first 6 digest bytes, hex-encoded.
assert_eq!(url.len(), "/static/test.txt?v=".len() + 12); }
// The configured URL prefix is used for nested paths as well.
#[test]
fn static_files_url_rewriting_with_different_prefix() {
let mut static_files = StaticFiles::new(&StaticFilesConfig {
url: "/assets/".to_string(),
rewrite: StaticFilesPathRewriteMode::QueryParam,
cache_timeout: None,
});
let file = StaticFile::new("images/logo.png", "fake image data");
static_files.add_file(file);
let url = static_files.path_for("images/logo.png").unwrap();
assert!(url.starts_with("/assets/images/logo.png?v="));
}
// The versioned URL is stable across lookups and across re-adding a file
// with identical content.
#[test]
fn static_files_hash_consistency() {
let mut static_files = StaticFiles::new(&StaticFilesConfig {
url: "/static/".to_string(),
rewrite: StaticFilesPathRewriteMode::QueryParam,
cache_timeout: None,
});
let file = StaticFile::new("test.txt", "test content");
static_files.add_file(file);
let url1 = static_files.path_for("test.txt").unwrap().to_owned();
let url2 = static_files.path_for("test.txt").unwrap().to_owned();
assert_eq!(url1, url2);
let file = StaticFile::new("test.txt", "test content");
static_files.add_file(file);
let url3 = static_files.path_for("test.txt").unwrap();
assert_eq!(url1, url3);
}
// Replacing a file with different content changes its versioned URL.
#[test]
fn static_files_hash_changes_with_content() {
let mut static_files = StaticFiles::new(&StaticFilesConfig {
url: "/static/".to_string(),
rewrite: StaticFilesPathRewriteMode::QueryParam,
cache_timeout: None,
});
let file1 = StaticFile::new("test.txt", "content 1");
static_files.add_file(file1);
let url1 = static_files.path_for("test.txt").unwrap().to_owned();
let file2 = StaticFile::new("test.txt", "content 2");
static_files.add_file(file2);
let url2 = static_files.path_for("test.txt").unwrap();
assert_ne!(url1, url2);
}
}