axum_embed/
lib.rs

1//! `axum_embed` is a library that provides a service for serving embedded files using the `axum` web framework.
2//!
//! This library uses the `rust_embed` crate to embed files into the binary at compile time, and the `axum` crate to serve these files over HTTP.
4//!
5//! # Features
6//! - Serve embedded files over HTTP
7//! - Customizable 404, fallback, and index files
//! - Serve pre-compressed files when the client accepts the encoding and the compressed file exists
//! - Respond with `304 Not Modified` when the client already has the same file (based on ETag)
//! - Redirect to the directory URL with a trailing slash when a directory is requested without one
11//!
12//! # Example
13//! ```ignore
14//! # use rust_embed::RustEmbed;
15//! # use axum_embed::ServeEmbed;
16//! # use tokio::net::TcpListener;
17//! #
18//! #[derive(RustEmbed, Clone)]
19//! #[folder = "examples/assets/"]
20//! struct Assets;
21//!
22//! # #[tokio::main]
23//! # async fn main() -> anyhow::Result<()> {
24//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
25//! let serve_assets = ServeEmbed::<Assets>::new();
26//! let app = axum::Router::new().nest_service("/", serve_assets);
27//! axum::serve(listener, app).await?;
28//!
29//! # Ok(())
30//! # }
31//! ```
32//!
33//! # Usage
34//!
35//! Please see the [examples](https://github.com/informationsea/axum-embed/tree/main/examples) directory for a working example.
36//!
37//! ## Serve compressed file
38//!
39//! The `axum_embed` library has the capability to serve compressed files, given that the client supports it and the compressed file is available.
40//! The compression methods supported include `br` (Brotli), `gzip`, and `deflate`.
//! If the client supports multiple compression methods, `axum_embed` tries them in the order they are listed in the `Accept-Encoding` header and serves the first one for which a compressed file exists. Note that quality weights (`q` values) are not used to reorder this selection.
42//! In the absence of client support for any compression methods, `axum_embed` will serve the file in its uncompressed form.
//! If a file with the same name plus the extension `.br` (for Brotli), `.gz` (for GZip), or `.zz` (for Deflate) is embedded, `axum_embed` will serve the file in its compressed form.
//! The uncompressed file must also be embedded: a compressed variant is never served on its own.
45use std::{borrow::Cow, convert::Infallible, future::Future, pin::Pin, sync::Arc, task::Poll};
46
47use axum_core::body::Body;
48use axum_core::extract::Request;
49use axum_core::response::Response;
50use chrono::{DateTime, Utc};
51use http::StatusCode;
52use rust_embed::RustEmbed;
53use tower_service::Service;
54
55#[derive(Clone, RustEmbed)]
56#[folder = "src/assets"]
57struct DefaultFallback;
58
/// [`FallbackBehavior`] selects how the server answers a request for a resource that is not
/// embedded.
///
/// The default is [`FallbackBehavior::NotFound`], the same behavior [`ServeEmbed::new`] uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum FallbackBehavior {
    /// Respond with the fallback resource and a `404 Not Found` status.
    #[default]
    NotFound,
    /// Respond with a `307 Temporary Redirect` whose `Location` is `/` followed by the
    /// configured fallback file path.
    Redirect,
    /// Respond with the fallback resource and a `200 OK` status, e.g. to let a single-page
    /// application handle routing on the client side.
    Ok,
}
69
70/// [`ServeEmbed`] is a struct that represents a service for serving embedded files.
71///
72/// # Parameters
73/// - `E`: A type that implements the [`RustEmbed`] and `Clone` trait. This type represents the embedded files.
74///
75/// # Example
76/// ```ignore
77/// # use rust_embed::RustEmbed;
78/// # use axum_embed::ServeEmbed;
79/// # use tokio::net::TcpListener;
80/// #
81/// #[derive(RustEmbed, Clone)]
82/// #[folder = "examples/assets/"]
83/// struct Assets;
84///
85/// # #[tokio::main]
86/// # async fn main() -> anyhow::Result<()> {
87/// let listener = TcpListener::bind("127.0.0.1:8080").await?;
88/// let serve_assets = ServeEmbed::<Assets>::new();
89/// let app = axum::Router::new().nest_service("/", serve_assets);
90/// axum::serve(listener, app).await?;
91///
92/// # Ok(())
93/// # }
94/// ```
95#[derive(Debug, Clone)]
96pub struct ServeEmbed<E: RustEmbed + Clone> {
97    _phantom: std::marker::PhantomData<E>,
98    fallback_file: Arc<Option<String>>,
99    fallback_behavior: FallbackBehavior,
100    index_file: Arc<Option<String>>,
101}
102
103impl<E: RustEmbed + Clone> ServeEmbed<E> {
104    /// Constructs a new `ServeEmbed` instance with default parameters.
105    ///
106    /// This function calls `with_parameters` internally with `None` for `fallback_file`, [`FallbackBehavior::NotFound`] for `fallback_behavior`, and `"index.html"` for `index_file`.
107    ///
108    /// # Returns
109    /// A new `ServeEmbed` instance with default parameters.
110    pub fn new() -> Self {
111        Self::with_parameters(
112            None,
113            FallbackBehavior::NotFound,
114            Some("index.html".to_owned()),
115        )
116    }
117
118    /// Constructs a new `ServeEmbed` instance with the provided parameters.
119    ///
120    /// # Parameters
121    /// - `fallback_file`: The path of the file to serve when a requested file is not found. If `None`, a default 404 response is served.
122    /// - `fallback_behavior`: The behavior of the server when a requested file is not found. Please see [`FallbackBehavior`] for more information.
123    /// - `index_file`: The name of the file to serve when a directory is accessed. If `None`, a 404 response is served for directory.
124    ///
125    /// # Returns
126    /// A new `ServeEmbed` instance.
127    pub fn with_parameters(
128        fallback_file: Option<String>,
129        fallback_behavior: FallbackBehavior,
130        index_file: Option<String>,
131    ) -> Self {
132        Self {
133            _phantom: std::marker::PhantomData,
134            fallback_file: Arc::new(fallback_file),
135            fallback_behavior,
136            index_file: Arc::new(index_file),
137        }
138    }
139}
140
141impl<E: RustEmbed + Clone, T: Send + 'static> Service<http::request::Request<T>> for ServeEmbed<E> {
142    type Response = Response;
143    type Error = Infallible;
144    type Future = ServeFuture<E, T>;
145
146    fn poll_ready(
147        &mut self,
148        _cx: &mut std::task::Context<'_>,
149    ) -> std::task::Poll<Result<(), Self::Error>> {
150        Poll::Ready(Ok(()))
151    }
152
153    fn call(&mut self, req: http::request::Request<T>) -> Self::Future {
154        ServeFuture {
155            _phantom: std::marker::PhantomData,
156            fallback_behavior: self.fallback_behavior,
157            fallback_file: self.fallback_file.clone(),
158            index_file: self.index_file.clone(),
159            request: req,
160        }
161    }
162}
163
/// Content codings this service can serve from pre-compressed embedded files.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum CompressionMethod {
    /// No compression: the file is served exactly as embedded.
    Identity,
    /// Brotli (`br`), stored with a `.br` suffix.
    Brotli,
    /// Gzip (`gzip`), stored with a `.gz` suffix.
    Gzip,
    /// Zlib-wrapped deflate (`deflate`), stored with a `.zz` suffix.
    Zlib,
}

impl CompressionMethod {
    /// Suffix appended to an asset's path to locate its pre-compressed variant.
    ///
    /// `Identity` maps to the empty suffix, i.e. the asset itself.
    fn extension(self) -> &'static str {
        match self {
            CompressionMethod::Brotli => ".br",
            CompressionMethod::Gzip => ".gz",
            CompressionMethod::Zlib => ".zz",
            CompressionMethod::Identity => "",
        }
    }
}
182
183fn from_acceptable_encoding(acceptable_encoding: Option<&str>) -> Vec<CompressionMethod> {
184    let mut compression_methods = Vec::new();
185
186    let mut identity_found = false;
187    for acceptable_encoding in acceptable_encoding.unwrap_or("").split(',') {
188        let acceptable_encoding = acceptable_encoding.trim().split(';').next().unwrap();
189        if acceptable_encoding == "br" {
190            compression_methods.push(CompressionMethod::Brotli);
191        } else if acceptable_encoding == "gzip" {
192            compression_methods.push(CompressionMethod::Gzip);
193        } else if acceptable_encoding == "deflate" {
194            compression_methods.push(CompressionMethod::Zlib);
195        } else if acceptable_encoding == "identity" {
196            compression_methods.push(CompressionMethod::Identity);
197            identity_found = true;
198        }
199    }
200
201    if !identity_found {
202        compression_methods.push(CompressionMethod::Identity);
203    }
204
205    compression_methods
206}
207
208struct GetFileResult<'a> {
209    path: Cow<'a, str>,
210    file: Option<rust_embed::EmbeddedFile>,
211    should_redirect: Option<String>,
212    compression_method: CompressionMethod,
213    is_fallback: bool,
214}
215
216/// `ServeFuture` is a future that represents a service for serving embedded files.
217/// This future is created by `ServeEmbed`.
218/// This future is not intended to be used directly.
219#[derive(Debug, Clone)]
220pub struct ServeFuture<E: RustEmbed, T> {
221    _phantom: std::marker::PhantomData<E>,
222    fallback_behavior: FallbackBehavior,
223    fallback_file: Arc<Option<String>>,
224    index_file: Arc<Option<String>>,
225    request: Request<T>,
226}
227
228impl<E: RustEmbed, T> ServeFuture<E, T> {
229    /// Attempts to get a file from the embedded files based on the provided path and acceptable encodings.
230    ///
231    /// # Parameters
232    /// - `path`: The path of the requested file. This should be a relative path from the root of the embedded files.
233    /// - `acceptable_encoding`: A list of compression methods that the client can accept. This is typically obtained from the `Accept-Encoding` header of the HTTP request.
234    ///
235    /// # Returns
236    /// A `GetFileResult` instance. If a file is found that matches the path and one of the acceptable encodings, it is included in the result. Otherwise, the result includes the path and `None` for the file.
237    fn get_file<'a>(
238        &self,
239        path: &'a str,
240        acceptable_encoding: &[CompressionMethod],
241    ) -> GetFileResult<'a> {
242        let mut path_candidate = Cow::Borrowed(path.trim_start_matches('/'));
243
244        if path_candidate == "" {
245            if let Some(index_file) = self.index_file.as_ref() {
246                path_candidate = Cow::Owned(index_file.to_string());
247            }
248        } else if path_candidate.ends_with('/') {
249            if let Some(index_file) = self.index_file.as_ref().as_ref() {
250                let new_path_candidate = format!("{}{}", path_candidate, index_file);
251                if E::get(&new_path_candidate).is_some() {
252                    path_candidate = Cow::Owned(new_path_candidate);
253                }
254            }
255        } else {
256            if let Some(index_file) = self.index_file.as_ref().as_ref() {
257                let new_path_candidate = format!("{}/{}", path_candidate, index_file);
258                if E::get(&new_path_candidate).is_some() {
259                    return GetFileResult {
260                        path: Cow::Owned(new_path_candidate),
261                        file: None,
262                        should_redirect: Some(format!("/{}/", path_candidate)),
263                        compression_method: CompressionMethod::Identity,
264                        is_fallback: false,
265                    };
266                }
267            }
268        }
269
270        let mut file = E::get(&path_candidate);
271        let mut compressed_method = CompressionMethod::Identity;
272
273        if file.is_some() {
274            for one_method in acceptable_encoding {
275                if let Some(x) = E::get(&format!("{}{}", path_candidate, one_method.extension())) {
276                    file = Some(x);
277                    compressed_method = *one_method;
278                    break;
279                }
280            }
281        }
282
283        GetFileResult {
284            path: path_candidate,
285            file,
286            should_redirect: None,
287            compression_method: compressed_method,
288            is_fallback: false,
289        }
290    }
291
292    fn get_file_with_fallback<'a, 'b: 'a>(
293        &'b self,
294        path: &'a str,
295        acceptable_encoding: &[CompressionMethod],
296    ) -> GetFileResult<'a> {
297        let first_try = self.get_file(path, acceptable_encoding);
298        if first_try.file.is_some() || first_try.should_redirect.is_some() {
299            return first_try;
300        }
301        if let Some(fallback_file) = self.fallback_file.as_ref().as_ref() {
302            if fallback_file != path && self.fallback_behavior == FallbackBehavior::Redirect {
303                return GetFileResult {
304                    path: Cow::Borrowed(path),
305                    file: None,
306                    should_redirect: Some(format!("/{}", fallback_file)),
307                    compression_method: CompressionMethod::Identity,
308                    is_fallback: true,
309                };
310            }
311            let mut fallback_try = self.get_file(fallback_file, acceptable_encoding);
312            fallback_try.is_fallback = true;
313            if fallback_try.file.is_some() {
314                return fallback_try;
315            }
316        }
317        GetFileResult {
318            path: Cow::Borrowed("404.html"),
319            file: DefaultFallback::get("404.html"),
320            should_redirect: None,
321            compression_method: CompressionMethod::Identity,
322            is_fallback: true,
323        }
324    }
325}
326
327impl<E: RustEmbed, T> Future for ServeFuture<E, T> {
328    type Output = Result<Response<Body>, Infallible>;
329
330    fn poll(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
331        // Accept only GET and HEAD method
332        if self.request.method() != http::Method::GET && self.request.method() != http::Method::HEAD
333        {
334            return Poll::Ready(Ok(Response::builder()
335                .status(StatusCode::METHOD_NOT_ALLOWED)
336                .header(http::header::CONTENT_TYPE, "text/plain")
337                .body(Body::from("Method not allowed"))
338                .unwrap()));
339        }
340
341        // get embedded file for the requested path
342        let (path, file, compression_method, is_fallback) = match self.get_file_with_fallback(
343            self.request.uri().path(),
344            &from_acceptable_encoding(
345                self.request
346                    .headers()
347                    .get(http::header::ACCEPT_ENCODING)
348                    .map(|x| x.to_str().ok())
349                    .flatten(),
350            ),
351        ) {
352            // if the file is found, return it
353            GetFileResult {
354                path,
355                file: Some(file),
356                should_redirect: None,
357                compression_method,
358                is_fallback,
359            } => (path, file, compression_method, is_fallback),
360            // if the path is a directory and the client does not have a trailing slash, redirect to the directory with a trailing slash
361            GetFileResult {
362                path: _,
363                file: _,
364                should_redirect: Some(should_redirect),
365                compression_method: _,
366                is_fallback,
367            } => {
368                return Poll::Ready(Ok(Response::builder()
369                    .status(if is_fallback {
370                        StatusCode::TEMPORARY_REDIRECT
371                    } else {
372                        StatusCode::MOVED_PERMANENTLY
373                    })
374                    .header(http::header::LOCATION, should_redirect)
375                    .header(http::header::CONTENT_TYPE, "text/plain")
376                    .body(if is_fallback {
377                        Body::from("Temporary redirect")
378                    } else {
379                        Body::from("Moved permanently")
380                    })
381                    .unwrap()));
382            }
383            // if the file is not found, return 404
384            _ => {
385                unreachable!();
386            }
387        };
388
389        // If the client has the same file, return 304
390        if !is_fallback
391            && self
392                .request
393                .headers()
394                .get(http::header::IF_NONE_MATCH)
395                .and_then(|value| {
396                    value
397                        .to_str()
398                        .ok()
399                        .and_then(|value| Some(value.trim_matches('"')))
400                })
401                == Some(hash_to_string(&file.metadata.sha256_hash()).as_str())
402        {
403            return Poll::Ready(Ok(Response::builder()
404                .status(StatusCode::NOT_MODIFIED)
405                .body(Body::empty())
406                .unwrap()));
407        }
408
409        // build response and set headers
410        let mut response_builder = Response::builder()
411            .header(
412                http::header::CONTENT_TYPE,
413                mime_guess::from_path(path.as_ref())
414                    .first_or_octet_stream()
415                    .to_string(),
416            )
417            .header(
418                http::header::ETAG,
419                hash_to_string(&file.metadata.sha256_hash()),
420            );
421
422        match compression_method {
423            CompressionMethod::Identity => {}
424            CompressionMethod::Brotli => {
425                response_builder = response_builder.header(http::header::CONTENT_ENCODING, "br");
426            }
427            CompressionMethod::Gzip => {
428                response_builder = response_builder.header(http::header::CONTENT_ENCODING, "gzip");
429            }
430            CompressionMethod::Zlib => {
431                response_builder =
432                    response_builder.header(http::header::CONTENT_ENCODING, "deflate");
433            }
434        }
435
436        if let Some(last_modified) = file.metadata.last_modified() {
437            response_builder =
438                response_builder.header(http::header::LAST_MODIFIED, date_to_string(last_modified));
439        }
440
441        if is_fallback && self.fallback_behavior != FallbackBehavior::Ok {
442            response_builder = response_builder.status(StatusCode::NOT_FOUND);
443        } else {
444            response_builder = response_builder.status(StatusCode::OK);
445        }
446
447        Poll::Ready(Ok(response_builder
448            .body(file.data.to_owned().into())
449            .unwrap()))
450    }
451}
452
/// Renders a SHA-256 digest as 64 lowercase hexadecimal characters.
///
/// Uses a digit lookup table instead of calling `format!` per byte, which allocated a temporary
/// `String` for each of the 32 bytes on every request.
fn hash_to_string(hash: &[u8; 32]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut s = String::with_capacity(hash.len() * 2);
    for &byte in hash {
        s.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        s.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    s
}
460
461fn date_to_string(date: u64) -> String {
462    DateTime::<Utc>::from_timestamp(date as i64, 0)
463        .unwrap()
464        .format("%a, %d %b %Y %H:%M:%S GMT")
465        .to_string()
466}
467
468#[cfg(test)]
469mod test;