use std::convert::Infallible;
use std::future::{Ready, ready};
use std::task::{Context, Poll};
use bytes::Bytes;
use http::header::{ALLOW, CACHE_CONTROL, CONTENT_TYPE, ETAG, IF_NONE_MATCH};
use http::{HeaderValue, Method, Request, Response, StatusCode};
use http_body_util::Full;
use tower_service::Service;
use crate::Assets;
/// A [`Service`] that answers HTTP requests from a `'static` [`Assets`] table.
///
/// The only state is a shared `'static` reference, so the service is `Copy` and can be
/// handed to every connection without allocation.
#[derive(Debug, Clone, Copy)]
pub struct ServeEmbedded {
    assets: &'static Assets,
}
impl ServeEmbedded {
pub(crate) fn new(assets: &'static Assets) -> Self {
Self { assets }
}
fn respond(
&self,
method: &Method,
path: &str,
if_none_match: Option<&HeaderValue>,
) -> Response<Full<Bytes>> {
if method != Method::GET && method != Method::HEAD {
return Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.header(ALLOW, HeaderValue::from_static("GET, HEAD"))
.body(empty())
.unwrap();
}
let Some(resolved) = self.assets.resolve(path) else {
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(empty())
.unwrap();
};
let file = resolved.file;
if if_none_match.is_some_and(|inm| etag_matches(inm.as_bytes(), file.etag)) {
let mut builder = Response::builder()
.status(StatusCode::NOT_MODIFIED)
.header(ETAG, file.etag);
if let Some(cache_control) = resolved.cache_control {
builder = builder.header(CACHE_CONTROL, cache_control);
}
return builder.body(empty()).unwrap();
}
let body = if method == Method::HEAD {
empty()
} else {
Full::new(Bytes::from_static(file.bytes))
};
let mut builder = Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, file.content_type)
.header(ETAG, file.etag);
if let Some(cache_control) = resolved.cache_control {
builder = builder.header(CACHE_CONTROL, cache_control);
}
builder.body(body).unwrap()
}
}
impl<B> Service<Request<B>> for ServeEmbedded {
type Response = Response<Full<Bytes>>;
type Error = Infallible;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<B>) -> Self::Future {
let resp = self.respond(
req.method(),
req.uri().path(),
req.headers().get(IF_NONE_MATCH),
);
ready(Ok(resp))
}
}
/// Produces a response body with no content.
///
/// Used for error, `304 Not Modified` and `HEAD` responses.
fn empty() -> Full<Bytes> {
    Full::default()
}
/// Reports whether an `If-None-Match` field value matches `etag`.
///
/// `If-None-Match` uses weak comparison (RFC 9110 §13.1.2), so a `W/` prefix is ignored on
/// *both* the client's candidates and `etag`. Strong and weak forms of the same opaque tag
/// therefore match. A `*` candidate matches any existing representation.
///
/// A value that is not valid UTF-8 never matches. Empty list members (e.g. from a trailing
/// comma) are skipped.
fn etag_matches(if_none_match: &[u8], etag: &str) -> bool {
    /// Drops the weakness indicator, leaving the quoted opaque tag.
    fn opaque_tag(tag: &str) -> &str {
        tag.strip_prefix("W/").unwrap_or(tag)
    }

    let Ok(header) = std::str::from_utf8(if_none_match) else {
        return false;
    };
    let wanted = opaque_tag(etag);
    header
        .split(',')
        .map(str::trim)
        .filter(|candidate| !candidate.is_empty())
        .any(|candidate| candidate == "*" || opaque_tag(candidate) == wanted)
}
#[cfg(test)]
mod tests {
    use super::etag_matches;

    #[test]
    fn matches_exact_and_wildcard_and_weak() {
        assert!(etag_matches(b"\"abc\"", "\"abc\""));
        assert!(etag_matches(b"*", "\"abc\""));
        assert!(etag_matches(b"W/\"abc\"", "\"abc\""));
        assert!(etag_matches(b"\"x\", \"abc\", \"y\"", "\"abc\""));
        assert!(!etag_matches(b"\"nope\"", "\"abc\""));
        assert!(!etag_matches(b"", "\"abc\""));
    }

    /// Surrounding whitespace is trimmed, and empty list members never match.
    #[test]
    fn tolerates_whitespace_and_empty_members() {
        assert!(etag_matches(b"  \"abc\"  ,", "\"abc\""));
        assert!(!etag_matches(b"\"x\",", "\"abc\""));
        assert!(!etag_matches(b" , ,", "\"abc\""));
        assert!(!etag_matches(b"W/\"nope\"", "\"abc\""));
    }

    /// A header value that is not UTF-8 is treated as "no match", not as an error.
    #[test]
    fn rejects_non_utf8_header() {
        assert!(!etag_matches(&[0xff, b'"', b'a', b'"'], "\"abc\""));
    }
}