memory_serve/
lib.rs

1#![allow(clippy::needless_doctest_main)]
2#![doc = include_str!("../README.md")]
3use axum::{
4    http::{HeaderMap, StatusCode},
5    routing::get,
6};
7use std::future::ready;
8use tracing::info;
9
10mod asset;
11mod build;
12mod cache_control;
13mod load;
14mod options;
15mod util;
16
17pub use crate::{
18    asset::Asset,
19    build::{load_directory, load_directory_with_embed, load_names_directories},
20    cache_control::CacheControl,
21};
22
23/// Helper struct to create and configure an axum to serve static files from
24/// memory.
25#[derive(Debug, Default)]
26pub struct MemoryServe {
27    options: options::ServeOptions,
28    assets: &'static [Asset],
29    aliases: Vec<(&'static str, &'static str)>,
30}
31
32impl MemoryServe {
33    /// Initiate a `MemoryServe` instance, taking the output of the `load!`
34    /// macro as an argument. `load!` selects the assets prepared during the
35    /// build step.
36    pub fn new(assets: &'static [Asset]) -> Self {
37        Self {
38            assets,
39            ..Default::default()
40        }
41    }
42
43    /// Which static file to serve on the route "/" (the index)
44    /// The path (or route) should be relative to the directory set with
45    /// the `ASSET_DIR` variable, but prepended with a slash.
46    /// By default this is `Some("/index.html")`
47    pub fn index_file(mut self, index_file: Option<&'static str>) -> Self {
48        self.options.index_file = index_file;
49
50        self
51    }
52
53    /// Whether to serve the corresponding index.html file when a route
54    /// matches a subdirectory
55    pub fn index_on_subdirectories(mut self, enable: bool) -> Self {
56        self.options.index_on_subdirectories = enable;
57
58        self
59    }
60
61    /// Which static file to serve when no other routes are matched, also see
62    /// [fallback](https://docs.rs/axum/latest/axum/routing/struct.Router.html#method.fallback)
63    /// The path (or route) should be relative to the directory set with
64    /// the `ASSET_DIR` variable, but prepended with a slash.
65    /// By default this is `None`, which means axum will return an empty
66    /// response with a HTTP 404 status code when no route matches.
67    pub fn fallback(mut self, fallback: Option<&'static str>) -> Self {
68        self.options.fallback = fallback;
69
70        self
71    }
72
73    /// What HTTP status code to return when a static file is returned by the
74    /// fallback handler.
75    pub fn fallback_status(mut self, fallback_status: StatusCode) -> Self {
76        self.options.fallback_status = fallback_status;
77
78        self
79    }
80
81    /// Whether to enable gzip compression. When set to `true`, clients that
82    /// accept gzip compressed files, but not brotli compressed files,
83    /// are served gzip compressed files.
84    pub fn enable_gzip(mut self, enable_gzip: bool) -> Self {
85        self.options.enable_gzip = enable_gzip;
86
87        self
88    }
89
90    /// Whether to enable brotli compression. When set to `true`, clients that
91    /// accept brotli compressed files are served brotli compressed files.
92    pub fn enable_brotli(mut self, enable_brotli: bool) -> Self {
93        self.options.enable_brotli = enable_brotli;
94
95        self
96    }
97
98    /// Whether to enable clean URLs. When set to `true`, the routing path for
99    /// HTML files will not include the extension so that a file located at
100    /// "/about.html" maps to "/about" instead of "/about.html".
101    pub fn enable_clean_url(mut self, enable_clean_url: bool) -> Self {
102        self.options.enable_clean_url = enable_clean_url;
103
104        self
105    }
106
107    /// The Cache-Control header to set for HTML files.
108    /// See [Cache control](index.html#cache-control) for options.
109    pub fn html_cache_control(mut self, html_cache_control: CacheControl) -> Self {
110        self.options.html_cache_control = html_cache_control;
111
112        self
113    }
114
115    /// Cache header to non-HTML files.
116    /// See [Cache control](index.html#cache-control) for options.
117    pub fn cache_control(mut self, cache_control: CacheControl) -> Self {
118        self.options.cache_control = cache_control;
119
120        self
121    }
122
123    /// Create an alias for a route / file
124    pub fn add_alias(mut self, from: &'static str, to: &'static str) -> Self {
125        self.aliases.push((from, to));
126
127        self
128    }
129
130    /// Create an axum `Router` instance that will serve the included static assets
131    /// Caution! This method leaks memory. It should only be called once (at startup).
132    pub fn into_router<S>(self) -> axum::Router<S>
133    where
134        S: Clone + Send + Sync + 'static,
135    {
136        let mut router = axum::Router::new();
137        let options = Box::leak(Box::new(self.options));
138
139        for asset in self.assets {
140            let (uncompressed_bytes, brotli_bytes, gzip_bytes) = asset.leak_bytes(options);
141
142            if !uncompressed_bytes.is_empty() {
143                if asset.is_compressed {
144                    info!(
145                        "serving {} {} -> {} bytes (compressed)",
146                        asset.route,
147                        uncompressed_bytes.len(),
148                        brotli_bytes.len()
149                    );
150                } else {
151                    info!("serving {} {} bytes", asset.route, uncompressed_bytes.len());
152                }
153            } else {
154                info!("serving {} (dynamically)", asset.route);
155            }
156
157            let handler = |headers: HeaderMap| {
158                ready(asset.handler(
159                    &headers,
160                    StatusCode::OK,
161                    uncompressed_bytes,
162                    brotli_bytes,
163                    gzip_bytes,
164                    options,
165                ))
166            };
167
168            if Some(asset.route) == options.fallback {
169                info!("serving {} as fallback", asset.route);
170
171                router = router.fallback(|headers: HeaderMap| {
172                    ready(asset.handler(
173                        &headers,
174                        options.fallback_status,
175                        uncompressed_bytes,
176                        brotli_bytes,
177                        gzip_bytes,
178                        options,
179                    ))
180                });
181            }
182
183            if let Some(index) = options.index_file {
184                if asset.route == index {
185                    info!("serving {} as index on /", asset.route);
186
187                    router = router.route("/", get(handler));
188                } else if options.index_on_subdirectories && asset.route.ends_with(index) {
189                    let path = &asset.route[..asset.route.len() - index.len()];
190                    info!("serving {} as index on {}", asset.route, path);
191
192                    router = router.route(path, get(handler));
193                }
194            }
195
196            let path = if options.enable_clean_url && asset.route.ends_with(".html") {
197                &asset.route[..asset.route.len() - 5]
198            } else {
199                asset.route
200            };
201            router = router.route(path, get(handler));
202
203            // add all aliases that point to the asset route
204            for (from, to) in self.aliases.iter() {
205                if *to == asset.route {
206                    info!("serving {} on alias {}", asset.route, from);
207
208                    router = router.route(from, get(handler));
209                }
210            }
211        }
212
213        router
214    }
215}
216
#[cfg(test)]
mod tests {
    use axum::{
        Router,
        body::Body,
        http::{
            self, HeaderMap, HeaderName, HeaderValue, Request, StatusCode,
            header::{self, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH},
        },
    };
    use std::sync::LazyLock;
    use tower::ServiceExt;

    use crate::{self as memory_serve, Asset, CacheControl, MemoryServe};

    /// Test assets from `./static`, loaded once and shared by every test.
    static ASSETS: LazyLock<&'static [Asset]> =
        LazyLock::new(|| memory_serve::build::load_test_assets("./static"));

    /// A `MemoryServe` over the shared test assets with default settings.
    fn test_load() -> MemoryServe {
        MemoryServe::new(*ASSETS)
    }

    /// Sends one GET request for `path` carrying the header `key: value` and
    /// returns the response status and headers.
    async fn get(
        router: Router,
        path: &str,
        key: &str,
        value: &str,
    ) -> (StatusCode, HeaderMap<HeaderValue>) {
        let response = router
            .oneshot(
                Request::builder()
                    .method(http::Method::GET)
                    .header(key, value)
                    .uri(path)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        (response.status(), response.headers().to_owned())
    }

    /// Returns header `name` as a string. Panics with the header name when it
    /// is missing or not a valid string, so a failing test says which one.
    fn get_header<'s>(headers: &'s HeaderMap, name: &HeaderName) -> &'s str {
        headers
            .get(name)
            .unwrap_or_else(|| panic!("response has no `{name}` header"))
            .to_str()
            .unwrap_or_else(|err| panic!("`{name}` header is not a valid string: {err}"))
    }

    #[tokio::test]
    async fn test_load_assets() {
        let routes: Vec<&str> = ASSETS.iter().map(|a| a.route).collect();
        let content_types: Vec<&str> = ASSETS.iter().map(|a| a.content_type).collect();
        let etags: Vec<&str> = ASSETS.iter().map(|a| a.etag).collect();

        assert_eq!(
            routes,
            [
                "/about.html",
                "/assets/icon.jpg",
                "/assets/index.css",
                "/assets/index.js",
                "/assets/stars.svg",
                "/blog/index.html",
                "/index.html"
            ]
        );
        assert_eq!(
            content_types,
            [
                "text/html",
                "image/jpeg",
                "text/css",
                "text/javascript",
                "image/svg+xml",
                "text/html",
                "text/html"
            ]
        );
        if cfg!(debug_assertions) && !cfg!(feature = "force-embed") {
            assert_eq!(etags, ["", "", "", "", "", "", ""]);
        } else {
            assert_eq!(
                etags,
                [
                    "56a0dcb83ec56b6c967966a1c06c7b1392e261069d0844aa4e910ca5c1e8cf58",
                    "e64f4683bf82d854df40b7246666f6f0816666ad8cd886a8e159535896eb03d6",
                    "ec4edeea111c854901385011f403e1259e3f1ba016dcceabb6d566316be3677b",
                    "86a7fdfd19700843e5f7344a63d27e0b729c2554c8572903ceee71f5658d2ecf",
                    "bd9dccc152de48cb7bedc35b9748ceeade492f6f904710f9c5d480bd6299cc7d",
                    "89e9873a8e49f962fe83ad2bfe6ac9b21ef7c1b4040b99c34eb783dccbadebc5",
                    "0639dc8aac157b58c74f65bbb026b2fd42bc81d9a0a64141df456fa23c214537"
                ]
            );
        }
    }

    #[tokio::test]
    async fn if_none_match_handling() {
        let memory_router = test_load().into_router();
        let (code, headers) =
            get(memory_router.clone(), "/index.html", "accept", "text/html").await;
        let etag = get_header(&headers, &header::ETAG);

        assert_eq!(code, 200);
        assert_eq!(
            etag,
            "0639dc8aac157b58c74f65bbb026b2fd42bc81d9a0a64141df456fa23c214537"
        );

        let (code, headers) = get(memory_router, "/index.html", "If-None-Match", etag).await;
        let length = get_header(&headers, &CONTENT_LENGTH);

        assert_eq!(code, 304);
        assert_eq!(length.parse::<i32>().unwrap(), 0);
    }

    #[tokio::test]
    async fn brotli_compression() {
        let memory_router = test_load().enable_brotli(true).into_router();
        let (code, headers) = get(
            memory_router.clone(),
            "/index.html",
            "accept-encoding",
            "br",
        )
        .await;
        let encoding = get_header(&headers, &CONTENT_ENCODING);
        let length = get_header(&headers, &CONTENT_LENGTH);

        assert_eq!(code, 200);
        assert_eq!(encoding, "br");
        assert_eq!(length.parse::<i32>().unwrap(), 178);

        // check disable compression
        let memory_router = test_load().enable_brotli(false).into_router();
        let (code, headers) = get(
            memory_router.clone(),
            "/index.html",
            "accept-encoding",
            "br",
        )
        .await;
        let length: &str = get_header(&headers, &CONTENT_LENGTH);

        assert_eq!(code, 200);
        assert_eq!(length.parse::<i32>().unwrap(), 437);
    }

    #[tokio::test]
    async fn gzip_compression() {
        let memory_router = test_load().enable_gzip(true).into_router();
        let (code, headers) = get(
            memory_router.clone(),
            "/index.html",
            "accept-encoding",
            "gzip",
        )
        .await;

        let encoding = get_header(&headers, &CONTENT_ENCODING);
        let length = get_header(&headers, &CONTENT_LENGTH);

        assert_eq!(code, 200);
        assert_eq!(encoding, "gzip");
        assert_eq!(length.parse::<i32>().unwrap(), 274);

        // check disable compression
        let memory_router = test_load().enable_gzip(false).into_router();
        let (code, headers) = get(
            memory_router.clone(),
            "/index.html",
            "accept-encoding",
            "gzip",
        )
        .await;
        let length: &str = get_header(&headers, &CONTENT_LENGTH);

        assert_eq!(code, 200);
        assert_eq!(length.parse::<i32>().unwrap(), 437);
    }

    #[tokio::test]
    async fn index_file() {
        let memory_router = test_load().index_file(None).into_router();

        let (code, _) = get(memory_router.clone(), "/", "accept", "*").await;
        assert_eq!(code, 404);

        let memory_router = test_load().index_file(Some("/index.html")).into_router();

        let (code, _) = get(memory_router.clone(), "/", "accept", "*").await;
        assert_eq!(code, 200);
    }

    #[tokio::test]
    async fn index_file_on_subdirs() {
        let memory_router = test_load()
            .index_file(Some("/index.html"))
            .index_on_subdirectories(false)
            .into_router();

        let (code, _) = get(memory_router.clone(), "/blog", "accept", "*").await;
        assert_eq!(code, 404);

        let memory_router = test_load()
            .index_file(Some("/index.html"))
            .index_on_subdirectories(true)
            .into_router();

        let (code, _) = get(memory_router.clone(), "/blog", "accept", "*").await;
        assert_eq!(code, 200);
    }

    #[tokio::test]
    async fn clean_url() {
        let memory_router = test_load().enable_clean_url(true).into_router();

        let (code, _) = get(memory_router.clone(), "/about.html", "accept", "*").await;
        assert_eq!(code, 404);

        let (code, _) = get(memory_router.clone(), "/about", "accept", "*").await;
        assert_eq!(code, 200);
    }

    #[tokio::test]
    async fn fallback() {
        let memory_router = test_load().into_router();
        let (code, _) = get(memory_router.clone(), "/foobar", "accept", "*").await;
        assert_eq!(code, 404);

        let memory_router = test_load().fallback(Some("/index.html")).into_router();
        let (code, headers) = get(memory_router.clone(), "/foobar", "accept", "*").await;
        let length = get_header(&headers, &CONTENT_LENGTH);
        assert_eq!(code, 404);
        assert_eq!(length.parse::<i32>().unwrap(), 437);

        let memory_router = test_load()
            .fallback(Some("/index.html"))
            .fallback_status(StatusCode::OK)
            .into_router();
        let (code, headers) = get(memory_router.clone(), "/foobar", "accept", "*").await;
        let length = get_header(&headers, &CONTENT_LENGTH);
        assert_eq!(code, 200);
        assert_eq!(length.parse::<i32>().unwrap(), 437);
    }

    #[tokio::test]
    async fn cache_control() {
        async fn check_cache_control(cache_control: CacheControl, expected: &str) {
            let memory_router = test_load().cache_control(cache_control).into_router();

            let (code, headers) =
                get(memory_router.clone(), "/assets/icon.jpg", "accept", "*").await;

            let cache_control = get_header(&headers, &CACHE_CONTROL);
            assert_eq!(code, 200);
            assert_eq!(cache_control, expected);
        }

        async fn check_html_cache_control(cache_control: CacheControl, expected: &str) {
            let memory_router = test_load().html_cache_control(cache_control).into_router();

            let (code, headers) = get(memory_router.clone(), "/index.html", "accept", "*").await;
            let cache_control = get_header(&headers, &CACHE_CONTROL);
            assert_eq!(code, 200);
            assert_eq!(cache_control, expected);
        }

        // Constructors instead of values: each check builds fresh instances for
        // the option and for the expected header, so no `Copy`/`Clone` is needed.
        let variants: [fn() -> CacheControl; 4] = [
            || CacheControl::NoCache,
            || CacheControl::Short,
            || CacheControl::Medium,
            || CacheControl::Long,
        ];

        for variant in variants {
            check_cache_control(variant(), variant().as_header().1.to_str().unwrap()).await;
        }

        for variant in variants {
            check_html_cache_control(variant(), variant().as_header().1.to_str().unwrap())
                .await;
        }
    }

    #[tokio::test]
    async fn aliases() {
        let memory_router = test_load()
            .add_alias("/foobar", "/index.html")
            .add_alias("/baz", "/index.html")
            .into_router();
        let (code, _) = get(memory_router.clone(), "/foobar", "accept", "*").await;
        assert_eq!(code, 200);

        let (code, _) = get(memory_router.clone(), "/baz", "accept", "*").await;
        assert_eq!(code, 200);

        let (code, _) = get(memory_router.clone(), "/index.html", "accept", "*").await;
        assert_eq!(code, 200);

        let (code, _) = get(memory_router.clone(), "/barfoo", "accept", "*").await;
        assert_eq!(code, 404);
    }
}