static_serve/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::convert::Infallible;
4
5use axum::{
6    extract::FromRequestParts,
7    http::{
8        header::{
9            HeaderValue, ACCEPT_ENCODING, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_TYPE, ETAG,
10            IF_NONE_MATCH, VARY,
11        },
12        request::Parts,
13        StatusCode,
14    },
15    response::IntoResponse,
16    routing::{get, MethodRouter},
17    Router,
18};
19use bytes::Bytes;
20
21pub use static_serve_macro::{embed_asset, embed_assets};
22
/// Which of the pre-compressed encodings the client is willing to receive.
///
/// Only gzip and zstd are tracked, because those are the only compressed
/// variants (`body_gz` / `body_zst`) an asset can be served in.
#[derive(Debug, Copy, Clone)]
struct AcceptEncoding {
    /// `true` when the client accepts a gzip-encoded body.
    pub gzip: bool,
    /// `true` when the client accepts a zstd-encoded body.
    pub zstd: bool,
}
31
32impl<S> FromRequestParts<S> for AcceptEncoding
33where
34    S: Send + Sync,
35{
36    type Rejection = Infallible;
37
38    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
39        let accept_encoding = parts.headers.get(ACCEPT_ENCODING);
40        let accept_encoding = accept_encoding
41            .and_then(|accept_encoding| accept_encoding.to_str().ok())
42            .unwrap_or_default();
43
44        Ok(Self {
45            gzip: accept_encoding.contains("gzip"),
46            zstd: accept_encoding.contains("zstd"),
47        })
48    }
49}
50
51/// Check if the  `IfNoneMatch` header is present
52#[derive(Debug)]
53struct IfNoneMatch(Option<HeaderValue>);
54
55impl IfNoneMatch {
56    /// required function for checking if `IfNoneMatch` is present
57    fn matches(&self, etag: &str) -> bool {
58        self.0
59            .as_ref()
60            .is_some_and(|if_none_match| if_none_match.as_bytes() == etag.as_bytes())
61    }
62}
63
64impl<S> FromRequestParts<S> for IfNoneMatch
65where
66    S: Send + Sync,
67{
68    type Rejection = Infallible;
69
70    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
71        let if_none_match = parts.headers.get(IF_NONE_MATCH).cloned();
72        Ok(Self(if_none_match))
73    }
74}
75
76#[doc(hidden)]
77#[expect(clippy::too_many_arguments)]
78/// The router for adding routes for static assets
79pub fn static_route<S>(
80    router: Router<S>,
81    web_path: &'static str,
82    content_type: &'static str,
83    etag: &'static str,
84    body: &'static [u8],
85    body_gz: Option<&'static [u8]>,
86    body_zst: Option<&'static [u8]>,
87    cache_busted: bool,
88) -> Router<S>
89where
90    S: Clone + Send + Sync + 'static,
91{
92    router.route(
93        web_path,
94        get(
95            move |accept_encoding: AcceptEncoding, if_none_match: IfNoneMatch| async move {
96                static_inner(StaticInnerData {
97                    content_type,
98                    etag,
99                    body,
100                    body_gz,
101                    body_zst,
102                    cache_busted,
103                    accept_encoding,
104                    if_none_match,
105                })
106            },
107        ),
108    )
109}
110
111#[doc(hidden)]
112/// Creates a route for a single static asset.
113///
114/// Used by the `embed_asset!` macro, so it needs to be `pub`.
115pub fn static_method_router(
116    content_type: &'static str,
117    etag: &'static str,
118    body: &'static [u8],
119    body_gz: Option<&'static [u8]>,
120    body_zst: Option<&'static [u8]>,
121    cache_busted: bool,
122) -> MethodRouter {
123    MethodRouter::get(
124        MethodRouter::new(),
125        move |accept_encoding: AcceptEncoding, if_none_match: IfNoneMatch| async move {
126            static_inner(StaticInnerData {
127                content_type,
128                etag,
129                body,
130                body_gz,
131                body_zst,
132                cache_busted,
133                accept_encoding,
134                if_none_match,
135            })
136        },
137    )
138}
139
140/// Struct of parameters for `static_inner` (to avoid `clippy::too_many_arguments`)
141///
142/// This differs from `StaticRouteData` because it
143/// includes the `AcceptEncoding` and `IfNoneMatch` fields
144/// and excludes the `web_path`
145struct StaticInnerData {
146    content_type: &'static str,
147    etag: &'static str,
148    body: &'static [u8],
149    body_gz: Option<&'static [u8]>,
150    body_zst: Option<&'static [u8]>,
151    cache_busted: bool,
152    accept_encoding: AcceptEncoding,
153    if_none_match: IfNoneMatch,
154}
155
156fn static_inner(static_inner_data: StaticInnerData) -> impl IntoResponse {
157    let StaticInnerData {
158        content_type,
159        etag,
160        body,
161        body_gz,
162        body_zst,
163        cache_busted,
164        accept_encoding,
165        if_none_match,
166    } = static_inner_data;
167
168    let optional_cache_control = if cache_busted {
169        Some([(
170            CACHE_CONTROL,
171            HeaderValue::from_static("public, max-age=31536000, immutable"),
172        )])
173    } else {
174        None
175    };
176
177    let resp_base = (
178        [
179            (CONTENT_TYPE, HeaderValue::from_static(content_type)),
180            (ETAG, HeaderValue::from_static(etag)),
181            (VARY, HeaderValue::from_static("Accept-Encoding")),
182        ],
183        optional_cache_control,
184    );
185
186    if if_none_match.matches(etag) {
187        return (resp_base, StatusCode::NOT_MODIFIED).into_response();
188    }
189
190    match (
191        (accept_encoding.gzip, body_gz),
192        (accept_encoding.zstd, body_zst),
193    ) {
194        (_, (true, Some(body_zst))) => (
195            resp_base,
196            [(CONTENT_ENCODING, HeaderValue::from_static("zstd"))],
197            Bytes::from_static(body_zst),
198        )
199            .into_response(),
200        ((true, Some(body_gz)), _) => (
201            resp_base,
202            [(CONTENT_ENCODING, HeaderValue::from_static("gzip"))],
203            Bytes::from_static(body_gz),
204        )
205            .into_response(),
206        _ => (resp_base, Bytes::from_static(body)).into_response(),
207    }
208}