memory_serve/
lib.rs

1#![doc = include_str!("../README.md")]
2use axum::{
3    http::{HeaderMap, StatusCode},
4    routing::get,
5};
6use std::future::ready;
7use tracing::info;
8mod asset;
9mod cache_control;
10mod util;
11
12#[allow(unused)]
13use crate as memory_serve;
14
15pub use crate::{asset::Asset, cache_control::CacheControl};
16pub use memory_serve_core::{
17    ENV_NAME, load_directory, load_directory_with_embed, load_names_directories,
18};
19pub use memory_serve_macros::load_assets;
20
/// Create a [`MemoryServe`] from the assets embedded by the *calling* crate's
/// own build script.
///
/// The macro includes `memory_serve_assets.rs` from the calling crate's
/// `OUT_DIR`. That file is expected to be generated by one of the
/// `load_directory*` functions called from that crate's `build.rs`.
///
/// * `from_local_build!()` serves the first asset directory listed in that file.
/// * `from_local_build!(title)` serves the directory registered under `title`.
///   `title` may be anything that implements `AsRef<str>`, e.g. a string
///   literal, a `String`, or a reference to either.
///
/// # Panics
///
/// Panics if no (matching) assets were found.
#[macro_export]
macro_rules! from_local_build {
    () => {{
        use $crate::{Asset, MemoryServe};

        let assets: &[(&str, &[Asset])] =
            include!(concat!(env!("OUT_DIR"), "/memory_serve_assets.rs"));

        if assets.is_empty() {
            panic!("No assets found, did you call a load_directory* function from your build.rs?");
        }

        MemoryServe::new(assets[0].1)
    }};
    ($title:expr) => {{
        use $crate::{Asset, MemoryServe};

        // Evaluate the title expression exactly once, then normalise it to a
        // plain `&str`. Comparing the `&&str` names below directly against a
        // `&str` literal would not compile (`&str: PartialEq<str>` does not exist).
        let title = $title;
        let title: &str = ::core::convert::AsRef::<str>::as_ref(&title);

        let assets: &[(&str, &[Asset])] =
            include!(concat!(env!("OUT_DIR"), "/memory_serve_assets.rs"));

        let assets = assets
            .iter()
            .find(|(n, _)| *n == title)
            .map(|(_, a)| *a)
            .unwrap_or_default();

        if assets.is_empty() {
            panic!("No assets found, did you call a load_directory* function from your build.rs?");
        }

        MemoryServe::new(assets)
    }};
}
54
55#[derive(Debug, Clone, Copy)]
56struct ServeOptions {
57    index_file: Option<&'static str>,
58    index_on_subdirectories: bool,
59    fallback: Option<&'static str>,
60    fallback_status: StatusCode,
61    html_cache_control: CacheControl,
62    cache_control: CacheControl,
63    enable_brotli: bool,
64    enable_gzip: bool,
65    enable_clean_url: bool,
66}
67
68impl Default for ServeOptions {
69    fn default() -> Self {
70        Self {
71            index_file: Some("/index.html"),
72            index_on_subdirectories: false,
73            fallback: None,
74            fallback_status: StatusCode::NOT_FOUND,
75            html_cache_control: CacheControl::Short,
76            cache_control: CacheControl::Medium,
77            enable_brotli: !cfg!(debug_assertions),
78            enable_gzip: !cfg!(debug_assertions),
79            enable_clean_url: false,
80        }
81    }
82}
83
84/// Helper struct to create and configure an axum to serve static files from
85/// memory.
86#[derive(Debug, Default)]
87pub struct MemoryServe {
88    options: ServeOptions,
89    assets: &'static [Asset],
90    aliases: Vec<(&'static str, &'static str)>,
91}
92
93impl MemoryServe {
94    /// Initiate a `MemoryServe` instance, takes the output of `load_assets!`
95    /// as an argument. `load_assets!` takes a directory name relative from
96    /// the project root.
97    pub fn new(assets: &'static [Asset]) -> Self {
98        Self {
99            assets,
100            ..Default::default()
101        }
102    }
103
104    /// Initiate a `MemoryServe` instance, takes the contents of `memory_serve_assets.bin`
105    /// created at build time.
106    /// Specify which asset directory to include using the environment variable `ASSET_DIR`.
107    pub fn from_env() -> Self {
108        let assets: &[(&str, &[Asset])] =
109            include!(concat!(env!("OUT_DIR"), "/memory_serve_assets.rs"));
110
111        if assets.is_empty() {
112            panic!("No assets found, did you forget to set the {ENV_NAME} environment variable?");
113        }
114
115        Self::new(assets[0].1)
116    }
117
118    /// Include a directory using a named environment variable, prefixed by ASSRT_DIR_.
119    /// Specify which asset directory to include using the environment variable `ASSET_DIR_<SOME NAME>`.
120    /// The name should be in uppercase.
121    /// For example to include assets from the public directory using the name PUBLIC, set the enirobment variable
122    /// `ASSET_DIR_PUBLIC=./public` and call `MemoryServe::from_name("PUBLIC")`.
123    pub fn from_env_name(name: &str) -> Self {
124        let assets: &[(&str, &[Asset])] =
125            include!(concat!(env!("OUT_DIR"), "/memory_serve_assets.rs"));
126
127        let assets = assets
128            .iter()
129            .find(|(n, _)| n == &name)
130            .map(|(_, a)| *a)
131            .unwrap_or_default();
132
133        if assets.is_empty() {
134            panic!(
135                "No assets found, did you forget to set the {ENV_NAME}_{name} environment variable?"
136            );
137        }
138
139        Self::new(assets)
140    }
141
142    /// Which static file to serve on the route "/" (the index)
143    /// The path (or route) should be relative to the directory set with
144    /// the `ASSET_DIR` variable, but prepended with a slash.
145    /// By default this is `Some("/index.html")`
146    pub fn index_file(mut self, index_file: Option<&'static str>) -> Self {
147        self.options.index_file = index_file;
148
149        self
150    }
151
152    /// Whether to serve the corresponding index.html file when a route
153    /// matches a subdirectory
154    pub fn index_on_subdirectories(mut self, enable: bool) -> Self {
155        self.options.index_on_subdirectories = enable;
156
157        self
158    }
159
160    /// Which static file to serve when no other routes are matched, also see
161    /// [fallback](https://docs.rs/axum/latest/axum/routing/struct.Router.html#method.fallback)
162    /// The path (or route) should be relative to the directory set with
163    /// the `ASSET_DIR` variable, but prepended with a slash.
164    /// By default this is `None`, which means axum will return an empty
165    /// response with a HTTP 404 status code when no route matches.
166    pub fn fallback(mut self, fallback: Option<&'static str>) -> Self {
167        self.options.fallback = fallback;
168
169        self
170    }
171
172    /// What HTTP status code to return when a static file is returned by the
173    /// fallback handler.
174    pub fn fallback_status(mut self, fallback_status: StatusCode) -> Self {
175        self.options.fallback_status = fallback_status;
176
177        self
178    }
179
180    /// Whether to enable gzip compression. When set to `true`, clients that
181    /// accept gzip compressed files, but not brotli compressed files,
182    /// are served gzip compressed files.
183    pub fn enable_gzip(mut self, enable_gzip: bool) -> Self {
184        self.options.enable_gzip = enable_gzip;
185
186        self
187    }
188
189    /// Whether to enable brotli compression. When set to `true`, clients that
190    /// accept brotli compressed files are served brotli compressed files.
191    pub fn enable_brotli(mut self, enable_brotli: bool) -> Self {
192        self.options.enable_brotli = enable_brotli;
193
194        self
195    }
196
197    /// Whether to enable clean URLs. When set to `true`, the routing path for
198    /// HTML files will not include the extension so that a file located at
199    /// "/about.html" maps to "/about" instead of "/about.html".
200    pub fn enable_clean_url(mut self, enable_clean_url: bool) -> Self {
201        self.options.enable_clean_url = enable_clean_url;
202
203        self
204    }
205
206    /// The Cache-Control header to set for HTML files.
207    /// See [Cache control](index.html#cache-control) for options.
208    pub fn html_cache_control(mut self, html_cache_control: CacheControl) -> Self {
209        self.options.html_cache_control = html_cache_control;
210
211        self
212    }
213
214    /// Cache header to non-HTML files.
215    /// See [Cache control](index.html#cache-control) for options.
216    pub fn cache_control(mut self, cache_control: CacheControl) -> Self {
217        self.options.cache_control = cache_control;
218
219        self
220    }
221
222    /// Create an alias for a route / file
223    pub fn add_alias(mut self, from: &'static str, to: &'static str) -> Self {
224        self.aliases.push((from, to));
225
226        self
227    }
228
229    /// Create an axum `Router` instance that will serve the included static assets
230    /// Caution! This method leaks memory. It should only be called once (at startup).
231    pub fn into_router<S>(self) -> axum::Router<S>
232    where
233        S: Clone + Send + Sync + 'static,
234    {
235        let mut router = axum::Router::new();
236        let options = Box::leak(Box::new(self.options));
237
238        for asset in self.assets {
239            let (uncompressed_bytes, brotli_bytes, gzip_bytes) = asset.leak_bytes(options);
240
241            if !uncompressed_bytes.is_empty() {
242                if asset.is_compressed {
243                    info!(
244                        "serving {} {} -> {} bytes (compressed)",
245                        asset.route,
246                        uncompressed_bytes.len(),
247                        brotli_bytes.len()
248                    );
249                } else {
250                    info!("serving {} {} bytes", asset.route, uncompressed_bytes.len());
251                }
252            } else {
253                info!("serving {} (dynamically)", asset.route);
254            }
255
256            let handler = |headers: HeaderMap| {
257                ready(asset.handler(
258                    &headers,
259                    StatusCode::OK,
260                    uncompressed_bytes,
261                    brotli_bytes,
262                    gzip_bytes,
263                    options,
264                ))
265            };
266
267            if Some(asset.route) == options.fallback {
268                info!("serving {} as fallback", asset.route);
269
270                router = router.fallback(|headers: HeaderMap| {
271                    ready(asset.handler(
272                        &headers,
273                        options.fallback_status,
274                        uncompressed_bytes,
275                        brotli_bytes,
276                        gzip_bytes,
277                        options,
278                    ))
279                });
280            }
281
282            if let Some(index) = options.index_file {
283                if asset.route == index {
284                    info!("serving {} as index on /", asset.route);
285
286                    router = router.route("/", get(handler));
287                } else if options.index_on_subdirectories && asset.route.ends_with(index) {
288                    let path = &asset.route[..asset.route.len() - index.len()];
289                    info!("serving {} as index on {}", asset.route, path);
290
291                    router = router.route(path, get(handler));
292                }
293            }
294
295            let path = if options.enable_clean_url && asset.route.ends_with(".html") {
296                &asset.route[..asset.route.len() - 5]
297            } else {
298                asset.route
299            };
300            router = router.route(path, get(handler));
301
302            // add all aliases that point to the asset route
303            for (from, to) in self.aliases.iter() {
304                if *to == asset.route {
305                    info!("serving {} on alias {}", asset.route, from);
306
307                    router = router.route(from, get(handler));
308                }
309            }
310        }
311
312        router
313    }
314}
315
#[cfg(test)]
mod tests {
    use crate::{self as memory_serve, Asset, CacheControl, MemoryServe};
    use axum::{
        Router,
        body::Body,
        http::{
            self, HeaderMap, HeaderName, HeaderValue, Request, StatusCode,
            header::{self, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH},
        },
    };
    use memory_serve_macros::load_assets;
    use tower::ServiceExt;

    /// Send a single GET request carrying one extra header through `router`.
    /// Returns the response status and a copy of the response headers.
    async fn get(
        router: Router,
        path: &str,
        key: &str,
        value: &str,
    ) -> (StatusCode, HeaderMap<HeaderValue>) {
        let request = Request::builder()
            .method(http::Method::GET)
            .uri(path)
            .header(key, value)
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();

        (response.status(), response.headers().clone())
    }

    /// Read a header that must be present and representable as a `&str`.
    fn get_header<'s>(headers: &'s HeaderMap, name: &HeaderName) -> &'s str {
        let value = headers.get(name).unwrap();
        value.to_str().unwrap()
    }

    /// Parse the (mandatory) `Content-Length` header.
    fn content_length(headers: &HeaderMap) -> i32 {
        get_header(headers, &CONTENT_LENGTH).parse::<i32>().unwrap()
    }

    /// A `MemoryServe` over the `../static` fixture directory.
    fn fixtures() -> MemoryServe {
        MemoryServe::new(load_assets!("../static"))
    }

    #[tokio::test]
    async fn test_load_assets() {
        let assets: &[Asset] = load_assets!("../static");

        let routes: Vec<&str> = assets.iter().map(|asset| asset.route).collect();
        assert_eq!(
            routes,
            [
                "/about.html",
                "/assets/icon.jpg",
                "/assets/index.css",
                "/assets/index.js",
                "/assets/stars.svg",
                "/blog/index.html",
                "/index.html"
            ]
        );

        let content_types: Vec<&str> = assets.iter().map(|asset| asset.content_type).collect();
        assert_eq!(
            content_types,
            [
                "text/html",
                "image/jpeg",
                "text/css",
                "text/javascript",
                "image/svg+xml",
                "text/html",
                "text/html"
            ]
        );

        let etags: Vec<&str> = assets.iter().map(|asset| asset.etag).collect();
        let embedded = !cfg!(debug_assertions) || cfg!(feature = "force-embed");
        if embedded {
            assert_eq!(
                etags,
                [
                    "56a0dcb83ec56b6c967966a1c06c7b1392e261069d0844aa4e910ca5c1e8cf58",
                    "e64f4683bf82d854df40b7246666f6f0816666ad8cd886a8e159535896eb03d6",
                    "ec4edeea111c854901385011f403e1259e3f1ba016dcceabb6d566316be3677b",
                    "86a7fdfd19700843e5f7344a63d27e0b729c2554c8572903ceee71f5658d2ecf",
                    "bd9dccc152de48cb7bedc35b9748ceeade492f6f904710f9c5d480bd6299cc7d",
                    "89e9873a8e49f962fe83ad2bfe6ac9b21ef7c1b4040b99c34eb783dccbadebc5",
                    "0639dc8aac157b58c74f65bbb026b2fd42bc81d9a0a64141df456fa23c214537"
                ]
            );
        } else {
            assert_eq!(etags, [""; 7]);
        }
    }

    #[tokio::test]
    async fn if_none_match_handling() {
        let router = fixtures().into_router();

        let (status, first_headers) =
            get(router.clone(), "/index.html", "accept", "text/html").await;
        let etag = get_header(&first_headers, &header::ETAG);
        assert_eq!(status, 200);
        assert_eq!(
            etag,
            "0639dc8aac157b58c74f65bbb026b2fd42bc81d9a0a64141df456fa23c214537"
        );

        // Re-requesting with the received ETag must yield an empty 304.
        let (status, headers) = get(router, "/index.html", "If-None-Match", etag).await;
        assert_eq!(status, 304);
        assert_eq!(content_length(&headers), 0);
    }

    #[tokio::test]
    async fn brotli_compression() {
        let router = fixtures().enable_brotli(true).into_router();
        let (status, headers) = get(router, "/index.html", "accept-encoding", "br").await;
        assert_eq!(status, 200);
        assert_eq!(get_header(&headers, &CONTENT_ENCODING), "br");
        assert_eq!(content_length(&headers), 178);

        // With brotli disabled the full-size body is returned.
        let router = fixtures().enable_brotli(false).into_router();
        let (status, headers) = get(router, "/index.html", "accept-encoding", "br").await;
        assert_eq!(status, 200);
        assert_eq!(content_length(&headers), 437);
    }

    #[tokio::test]
    async fn gzip_compression() {
        let router = fixtures().enable_gzip(true).into_router();
        let (status, headers) = get(router, "/index.html", "accept-encoding", "gzip").await;
        assert_eq!(status, 200);
        assert_eq!(get_header(&headers, &CONTENT_ENCODING), "gzip");
        assert_eq!(content_length(&headers), 274);

        // With gzip disabled the full-size body is returned.
        let router = fixtures().enable_gzip(false).into_router();
        let (status, headers) = get(router, "/index.html", "accept-encoding", "gzip").await;
        assert_eq!(status, 200);
        assert_eq!(content_length(&headers), 437);
    }

    #[tokio::test]
    async fn index_file() {
        let without_index = fixtures().index_file(None).into_router();
        let (status, _) = get(without_index, "/", "accept", "*").await;
        assert_eq!(status, 404);

        let with_index = fixtures().index_file(Some("/index.html")).into_router();
        let (status, _) = get(with_index, "/", "accept", "*").await;
        assert_eq!(status, 200);
    }

    #[tokio::test]
    async fn index_file_on_subdirs() {
        let cases: [(bool, u16); 2] = [(false, 404), (true, 200)];

        for (enabled, expected) in cases {
            let router = fixtures()
                .index_file(Some("/index.html"))
                .index_on_subdirectories(enabled)
                .into_router();
            let (status, _) = get(router, "/blog", "accept", "*").await;
            assert_eq!(status, expected);
        }
    }

    #[tokio::test]
    async fn clean_url() {
        let router = fixtures().enable_clean_url(true).into_router();

        // The extension-less route replaces the original `.html` one.
        let (status, _) = get(router.clone(), "/about.html", "accept", "*").await;
        assert_eq!(status, 404);

        let (status, _) = get(router, "/about", "accept", "*").await;
        assert_eq!(status, 200);
    }

    #[tokio::test]
    async fn fallback() {
        // Without a fallback asset unknown paths get a plain 404.
        let router = fixtures().into_router();
        let (status, _) = get(router, "/foobar", "accept", "*").await;
        assert_eq!(status, 404);

        // The fallback asset is served with the default 404 status...
        let router = fixtures().fallback(Some("/index.html")).into_router();
        let (status, headers) = get(router, "/foobar", "accept", "*").await;
        assert_eq!(status, 404);
        assert_eq!(content_length(&headers), 437);

        // ...unless another fallback status is configured.
        let router = fixtures()
            .fallback(Some("/index.html"))
            .fallback_status(StatusCode::OK)
            .into_router();
        let (status, headers) = get(router, "/foobar", "accept", "*").await;
        assert_eq!(status, 200);
        assert_eq!(content_length(&headers), 437);
    }

    #[tokio::test]
    async fn cache_control() {
        let variants = [
            CacheControl::NoCache,
            CacheControl::Short,
            CacheControl::Medium,
            CacheControl::Long,
        ];

        // Non-HTML assets follow `cache_control`.
        for variant in variants {
            let expected = variant.as_header();
            let router = fixtures().cache_control(variant).into_router();
            let (status, headers) = get(router, "/assets/icon.jpg", "accept", "*").await;
            assert_eq!(status, 200);
            assert_eq!(
                get_header(&headers, &CACHE_CONTROL),
                expected.1.to_str().unwrap()
            );
        }

        // HTML assets follow `html_cache_control`.
        for variant in variants {
            let expected = variant.as_header();
            let router = fixtures().html_cache_control(variant).into_router();
            let (status, headers) = get(router, "/index.html", "accept", "*").await;
            assert_eq!(status, 200);
            assert_eq!(
                get_header(&headers, &CACHE_CONTROL),
                expected.1.to_str().unwrap()
            );
        }
    }

    #[tokio::test]
    async fn aliases() {
        let router = fixtures()
            .add_alias("/foobar", "/index.html")
            .add_alias("/baz", "/index.html")
            .into_router();

        let expectations: [(&str, u16); 4] = [
            ("/foobar", 200),
            ("/baz", 200),
            ("/index.html", 200),
            ("/barfoo", 404),
        ];

        for (path, expected) in expectations {
            let (status, _) = get(router.clone(), path, "accept", "*").await;
            assert_eq!(status, expected, "unexpected status for {path}");
        }
    }
}