axum_embeddy/
lib.rs

1//! [![Crates.io](https://img.shields.io/crates/v/axum-embeddy)](https://crates.io/crates/axum-embeddy)
2//! [![Crates.io](https://img.shields.io/crates/l/axum-embeddy)](https://crates.io/crates/axum-embeddy)
3//! [![docs.rs](https://img.shields.io/docsrs/axum-embeddy)](https://docs.rs/axum-embeddy/)
4//! `axum_embeddy` is a library that provides a service for serving embedded files using the `axum` web framework.
5//!
6//! This library uses the `rust_embed` crate to embed files into the binary at compile time, and the `axum` crate to serve these files over HTTP.
7//!
8//! # Features
9//! - Serve embedded files over HTTP
10//! - Customizable 404, fallback, and index files
//! - Respond with compressed files if the client supports it and the compressed file exists
//! - Respond with 304 if the client already has the same file (based on ETag)
13//! - Redirect to the directory if the client requests a directory without a trailing slash
14//!
15//! # Example
16//! ```ignore
17//! # use rust_embed::RustEmbed;
18//! # use axum_embeddy::ServeEmbed;
19//! # use tokio::net::TcpListener;
20//! #
21//! #[derive(RustEmbed, Clone)]
22//! #[folder = "examples/assets/"]
23//! struct Assets;
24//!
25//! # #[tokio::main]
26//! # async fn main() -> anyhow::Result<()> {
27//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
28//! let serve_assets = ServeEmbed::<Assets>::new();
29//! let app = axum::Router::new().nest_service("/", serve_assets);
30//! axum::serve(listener, app).await?;
31//!
32//! # Ok(())
33//! # }
34//! ```
35//!
36//! # Usage
37//!
38//! Please see the [examples](https://github.com/informationsea/axum-embed/tree/main/examples) directory for a working example.
39//!
40//! ## Serve compressed file
41//!
42//! The `axum_embeddy` library has the capability to serve compressed files, given that the client supports it and the compressed file is available.
43//! The compression methods supported include `br` (Brotli), `gzip`, and `deflate`.
44//! If the client supports multiple compression methods, `axum_embeddy` will select the first one listed in the `Accept-Encoding` header. Please note that the weight of encoding is not considered in this selection.
45//! In the absence of client support for any compression methods, `axum_embeddy` will serve the file in its uncompressed form.
46//! If a file with the extension `.br` (for Brotli), `.gz` (for GZip), or `.zz` (for Deflate) is available, `axum_embeddy` will serve the file in its compressed form.
//! An uncompressed file must be available for the compressed file to be served.
48use std::{borrow::Cow, convert::Infallible, future::Future, pin::Pin, sync::Arc, task::Poll};
49
50use bytes::Bytes;
51use http::{Request, Response, StatusCode};
52use http_body_util::Full;
53use rust_embed::RustEmbed;
54use time::{OffsetDateTime, format_description};
55use tower_service::Service;
56
/// Assets that ship with this crate itself, embedded from `src/assets`.
///
/// Its `404.html` is used as the last-resort "not found" page when no
/// user-supplied fallback file can be served.
#[derive(Clone, RustEmbed)]
#[folder = "src/assets"]
struct DefaultFallback;
60
/// [`FallbackBehavior`] describes what the server does when a requested resource is not found.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FallbackBehavior {
    /// Respond with the fallback resource and a `404 Not Found` status code.
    NotFound,
    /// Redirect the client to the fallback resource (temporary redirect).
    Redirect,
    /// Respond with the fallback resource and a `200 OK` status code.
    Ok,
}
71
/// [`ServeEmbed`] is a struct that represents a service for serving embedded files.
///
/// # Parameters
/// - `E`: A type that implements the [`RustEmbed`] and `Clone` trait. This type represents the embedded files.
///
/// # Example
/// ```ignore
/// # use rust_embed::RustEmbed;
/// # use axum_embeddy::ServeEmbed;
/// # use tokio::net::TcpListener;
/// #
/// #[derive(RustEmbed, Clone)]
/// #[folder = "examples/assets/"]
/// struct Assets;
///
/// # #[tokio::main]
/// # async fn main() -> anyhow::Result<()> {
/// let listener = TcpListener::bind("127.0.0.1:8080").await?;
/// let serve_assets = ServeEmbed::<Assets>::new();
/// let app = axum::Router::new().nest_service("/", serve_assets);
/// axum::serve(listener, app).await?;
///
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct ServeEmbed<E: RustEmbed + Clone> {
    // Marker tying the service to the embedded asset type `E`; no value of `E` is stored.
    _phantom: std::marker::PhantomData<E>,
    // Path of the file served when the requested file is not found (`None`: built-in 404 page).
    // Wrapped in `Arc` so every per-request future can share it without copying the string.
    fallback_file: Arc<Option<String>>,
    // How a "not found" request is answered; see `FallbackBehavior`.
    fallback_behavior: FallbackBehavior,
    // File name served for directory requests (`None`: directories are not resolved to an index).
    index_file: Arc<Option<String>>,
}
104
105impl<E: RustEmbed + Clone> ServeEmbed<E> {
106    /// Constructs a new `ServeEmbed` instance with default parameters.
107    ///
108    /// This function calls `with_parameters` internally with `None` for `fallback_file`, [`FallbackBehavior::NotFound`] for `fallback_behavior`, and `"index.html"` for `index_file`.
109    ///
110    /// # Returns
111    /// A new `ServeEmbed` instance with default parameters.
112    pub fn new() -> Self {
113        Self::with_parameters(
114            None,
115            FallbackBehavior::NotFound,
116            Some("index.html".to_owned()),
117        )
118    }
119
120    /// Constructs a new `ServeEmbed` instance with the provided parameters.
121    ///
122    /// # Parameters
123    /// - `fallback_file`: The path of the file to serve when a requested file is not found. If `None`, a default 404 response is served.
124    /// - `fallback_behavior`: The behavior of the server when a requested file is not found. Please see [`FallbackBehavior`] for more information.
125    /// - `index_file`: The name of the file to serve when a directory is accessed. If `None`, a 404 response is served for directory.
126    ///
127    /// # Returns
128    /// A new `ServeEmbed` instance.
129    pub fn with_parameters(
130        fallback_file: Option<String>,
131        fallback_behavior: FallbackBehavior,
132        index_file: Option<String>,
133    ) -> Self {
134        Self {
135            _phantom: std::marker::PhantomData,
136            fallback_file: Arc::new(fallback_file),
137            fallback_behavior,
138            index_file: Arc::new(index_file),
139        }
140    }
141}
142
/// The default service is the one produced by `ServeEmbed::new`.
impl<E: RustEmbed + Clone> Default for ServeEmbed<E> {
    fn default() -> Self {
        Self::new()
    }
}
148
149impl<E: RustEmbed + Clone, T: Send + 'static> Service<http::request::Request<T>> for ServeEmbed<E> {
150    type Response = http::Response<Full<Bytes>>;
151    type Error = Infallible;
152    type Future = ServeFuture<E, T>;
153
154    fn poll_ready(
155        &mut self,
156        _cx: &mut std::task::Context<'_>,
157    ) -> std::task::Poll<Result<(), Self::Error>> {
158        Poll::Ready(Ok(()))
159    }
160
161    fn call(&mut self, req: http::request::Request<T>) -> Self::Future {
162        ServeFuture {
163            _phantom: std::marker::PhantomData,
164            fallback_behavior: self.fallback_behavior,
165            fallback_file: self.fallback_file.clone(),
166            index_file: self.index_file.clone(),
167            request: req,
168        }
169    }
170}
171
/// Content codings this crate can serve from pre-compressed sibling files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum CompressionMethod {
    /// No compression; the file is served as embedded.
    Identity,
    /// `br` content coding, stored next to the original file as `<name>.br`.
    Brotli,
    /// `gzip` content coding, stored as `<name>.gz`.
    Gzip,
    /// `deflate` content coding, stored as `<name>.zz`.
    Zlib,
}

impl CompressionMethod {
    /// Returns the suffix appended to a file name to locate its pre-compressed variant.
    ///
    /// [`CompressionMethod::Identity`] yields an empty suffix, i.e. the file itself.
    const fn extension(self) -> &'static str {
        match self {
            Self::Identity => "",
            Self::Brotli => ".br",
            Self::Gzip => ".gz",
            Self::Zlib => ".zz",
        }
    }
}

/// Parses an `Accept-Encoding` header value into the list of supported codings.
///
/// # Parameters
/// - `acceptable_encoding`: The raw header value, or `None` when the header is absent or unreadable.
///
/// # Returns
/// The recognized codings in the order the client listed them. Quality values are not used
/// for ordering, but a coding explicitly refused with `q=0` is left out. Unknown codings are
/// ignored. [`CompressionMethod::Identity`] is appended when it is not already in the list,
/// so the result is never empty and the plain file can always be served.
fn from_acceptable_encoding(acceptable_encoding: Option<&str>) -> Vec<CompressionMethod> {
    let mut compression_methods = Vec::new();
    let mut identity_found = false;

    for entry in acceptable_encoding.unwrap_or("").split(',') {
        let mut parameters = entry.split(';');
        // Trim after splitting on `;` so that `gzip ;q=0.5` is still recognized.
        let coding = parameters.next().unwrap_or("").trim();

        // `q=0` means "not acceptable"; offering such a coding would violate the request.
        let refused = parameters.any(|parameter| match parameter.split_once('=') {
            Some((name, value)) => {
                name.trim().eq_ignore_ascii_case("q")
                    && matches!(value.trim().parse::<f32>(), Ok(q) if q == 0.0)
            }
            None => false,
        });
        if refused {
            continue;
        }

        // Content-coding names are case-insensitive.
        if coding.eq_ignore_ascii_case("br") {
            compression_methods.push(CompressionMethod::Brotli);
        } else if coding.eq_ignore_ascii_case("gzip") {
            compression_methods.push(CompressionMethod::Gzip);
        } else if coding.eq_ignore_ascii_case("deflate") {
            compression_methods.push(CompressionMethod::Zlib);
        } else if coding.eq_ignore_ascii_case("identity") {
            compression_methods.push(CompressionMethod::Identity);
            identity_found = true;
        }
    }

    if !identity_found {
        compression_methods.push(CompressionMethod::Identity);
    }

    compression_methods
}
215
216fn cow_to_bytes(cow: Cow<'static, [u8]>) -> Bytes {
217    match cow {
218        Cow::Borrowed(x) => Bytes::from(x),
219        Cow::Owned(x) => Bytes::from(x),
220    }
221}
222
/// Outcome of resolving a request path against the embedded files.
struct GetFileResult<'a> {
    // Resolved path (leading slashes removed, index file applied where relevant);
    // the response content type is guessed from it.
    path: Cow<'a, str>,
    // The file to serve, if one was found; may be a pre-compressed variant.
    file: Option<rust_embed::EmbeddedFile>,
    // When set, the client is redirected to this location instead of being served a file.
    should_redirect: Option<String>,
    // Content coding of `file`.
    compression_method: CompressionMethod,
    // True when the result was produced by fallback handling rather than by the requested path.
    is_fallback: bool,
}
230
/// `ServeFuture` is a future that represents a service for serving embedded files.
/// This future is created by `ServeEmbed`.
/// This future is not intended to be used directly.
#[derive(Debug, Clone)]
pub struct ServeFuture<E: RustEmbed, T> {
    // Marker for the embedded asset type `E`; no value of `E` is stored.
    _phantom: std::marker::PhantomData<E>,
    // Configuration copied from the `ServeEmbed` that created this future.
    fallback_behavior: FallbackBehavior,
    fallback_file: Arc<Option<String>>,
    index_file: Arc<Option<String>>,
    // The request being answered; its method, URI path and headers are read when polled.
    request: Request<T>,
}
242
243impl<E: RustEmbed, T> ServeFuture<E, T> {
244    /// Attempts to get a file from the embedded files based on the provided path and acceptable encodings.
245    ///
246    /// # Parameters
247    /// - `path`: The path of the requested file. This should be a relative path from the root of the embedded files.
248    /// - `acceptable_encoding`: A list of compression methods that the client can accept. This is typically obtained from the `Accept-Encoding` header of the HTTP request.
249    ///
250    /// # Returns
251    /// A `GetFileResult` instance. If a file is found that matches the path and one of the acceptable encodings, it is included in the result. Otherwise, the result includes the path and `None` for the file.
252    fn get_file<'a>(
253        &self,
254        path: &'a str,
255        acceptable_encoding: &[CompressionMethod],
256    ) -> GetFileResult<'a> {
257        let mut path_candidate = Cow::Borrowed(path.trim_start_matches('/'));
258
259        if path_candidate.is_empty() {
260            if let Some(index_file) = self.index_file.as_ref() {
261                path_candidate = Cow::Owned(index_file.to_string());
262            }
263        } else if path_candidate.ends_with('/') {
264            if let Some(index_file) = self.index_file.as_ref().as_ref() {
265                let new_path_candidate = format!("{}{}", path_candidate, index_file);
266                if E::get(&new_path_candidate).is_some() {
267                    path_candidate = Cow::Owned(new_path_candidate);
268                }
269            }
270        } else if let Some(index_file) = self.index_file.as_ref().as_ref() {
271            let new_path_candidate = format!("{}/{}", path_candidate, index_file);
272            if E::get(&new_path_candidate).is_some() {
273                return GetFileResult {
274                    path: Cow::Owned(new_path_candidate),
275                    file: None,
276                    should_redirect: Some(format!("/{}/", path_candidate)),
277                    compression_method: CompressionMethod::Identity,
278                    is_fallback: false,
279                };
280            }
281        }
282
283        let mut file = E::get(&path_candidate);
284        let mut compressed_method = CompressionMethod::Identity;
285
286        if file.is_some() {
287            for one_method in acceptable_encoding {
288                if let Some(x) = E::get(&format!("{}{}", path_candidate, one_method.extension())) {
289                    file = Some(x);
290                    compressed_method = *one_method;
291                    break;
292                }
293            }
294        }
295
296        GetFileResult {
297            path: path_candidate,
298            file,
299            should_redirect: None,
300            compression_method: compressed_method,
301            is_fallback: false,
302        }
303    }
304
305    fn get_file_with_fallback<'a, 'b: 'a>(
306        &'b self,
307        path: &'a str,
308        acceptable_encoding: &[CompressionMethod],
309    ) -> GetFileResult<'a> {
310        let first_try = self.get_file(path, acceptable_encoding);
311        if first_try.file.is_some() || first_try.should_redirect.is_some() {
312            return first_try;
313        }
314        if let Some(fallback_file) = self.fallback_file.as_ref().as_ref() {
315            if fallback_file != path && self.fallback_behavior == FallbackBehavior::Redirect {
316                return GetFileResult {
317                    path: Cow::Borrowed(path),
318                    file: None,
319                    should_redirect: Some(format!("/{}", fallback_file)),
320                    compression_method: CompressionMethod::Identity,
321                    is_fallback: true,
322                };
323            }
324            let mut fallback_try = self.get_file(fallback_file, acceptable_encoding);
325            fallback_try.is_fallback = true;
326            if fallback_try.file.is_some() {
327                return fallback_try;
328            }
329        }
330        GetFileResult {
331            path: Cow::Borrowed("404.html"),
332            file: DefaultFallback::get("404.html"),
333            should_redirect: None,
334            compression_method: CompressionMethod::Identity,
335            is_fallback: true,
336        }
337    }
338}
339
340impl<E: RustEmbed, T> Future for ServeFuture<E, T> {
341    type Output = Result<Response<Full<Bytes>>, Infallible>;
342
343    fn poll(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
344        // Accept only GET and HEAD method
345        if self.request.method() != http::Method::GET && self.request.method() != http::Method::HEAD
346        {
347            return Poll::Ready(Ok(Response::builder()
348                .status(StatusCode::METHOD_NOT_ALLOWED)
349                .header(http::header::CONTENT_TYPE, "text/plain")
350                .body(Full::new(Bytes::from("Method not allowed")))
351                .unwrap()));
352        }
353
354        // get embedded file for the requested path
355        let (path, file, compression_method, is_fallback) = match self.get_file_with_fallback(
356            self.request.uri().path(),
357            &from_acceptable_encoding(
358                self.request
359                    .headers()
360                    .get(http::header::ACCEPT_ENCODING)
361                    .and_then(|x| x.to_str().ok()),
362            ),
363        ) {
364            // if the file is found, return it
365            GetFileResult {
366                path,
367                file: Some(file),
368                should_redirect: None,
369                compression_method,
370                is_fallback,
371            } => (path, file, compression_method, is_fallback),
372            // if the path is a directory and the client does not have a trailing slash, redirect to the directory with a trailing slash
373            GetFileResult {
374                path: _,
375                file: _,
376                should_redirect: Some(should_redirect),
377                compression_method: _,
378                is_fallback,
379            } => {
380                return Poll::Ready(Ok(Response::builder()
381                    .status(if is_fallback {
382                        StatusCode::TEMPORARY_REDIRECT
383                    } else {
384                        StatusCode::MOVED_PERMANENTLY
385                    })
386                    .header(http::header::LOCATION, should_redirect)
387                    .header(http::header::CONTENT_TYPE, "text/plain")
388                    .body(Full::new(if is_fallback {
389                        Bytes::from("Temporary redirect")
390                    } else {
391                        Bytes::from("Moved permanently")
392                    }))
393                    .unwrap()));
394            }
395            // if the file is not found, return 404
396            _ => {
397                unreachable!();
398            }
399        };
400
401        // If the client has the same file, return 304
402        if !is_fallback
403            && self
404                .request
405                .headers()
406                .get(http::header::IF_NONE_MATCH)
407                .and_then(|value| value.to_str().ok().map(|value| value.trim_matches('"')))
408                == Some(hash_to_string(&file.metadata.sha256_hash()).as_str())
409        {
410            return Poll::Ready(Ok(Response::builder()
411                .status(StatusCode::NOT_MODIFIED)
412                .body(Full::new(Bytes::from("")))
413                .unwrap()));
414        }
415
416        // build response and set headers
417        let mut response_builder = Response::builder()
418            .header(
419                http::header::CONTENT_TYPE,
420                mime_guess::from_path(path.as_ref())
421                    .first_or_octet_stream()
422                    .to_string(),
423            )
424            .header(
425                http::header::ETAG,
426                hash_to_string(&file.metadata.sha256_hash()),
427            );
428
429        match compression_method {
430            CompressionMethod::Identity => {}
431            CompressionMethod::Brotli => {
432                response_builder = response_builder.header(http::header::CONTENT_ENCODING, "br");
433            }
434            CompressionMethod::Gzip => {
435                response_builder = response_builder.header(http::header::CONTENT_ENCODING, "gzip");
436            }
437            CompressionMethod::Zlib => {
438                response_builder =
439                    response_builder.header(http::header::CONTENT_ENCODING, "deflate");
440            }
441        }
442
443        if let Some(last_modified) = file.metadata.last_modified() {
444            response_builder =
445                response_builder.header(http::header::LAST_MODIFIED, date_to_string(last_modified));
446        }
447
448        if is_fallback && self.fallback_behavior != FallbackBehavior::Ok {
449            response_builder = response_builder.status(StatusCode::NOT_FOUND);
450        } else {
451            response_builder = response_builder.status(StatusCode::OK);
452        }
453
454        Poll::Ready(Ok(response_builder
455            .body(Full::new(cow_to_bytes(file.data)))
456            .unwrap()))
457    }
458}
459
/// Encodes a 32-byte hash as a 64-character lowercase hexadecimal string.
fn hash_to_string(hash: &[u8; 32]) -> String {
    use std::fmt::Write as _;

    let mut s = String::with_capacity(64);
    for byte in hash {
        // Format straight into the buffer instead of allocating a temporary `String` per byte.
        write!(s, "{:02x}", byte).expect("writing to a String never fails");
    }
    s
}
467
468fn date_to_string(date: u64) -> String {
469    let format = format_description::parse(
470        "[weekday repr:short], [day] [month repr:short] [year] [hour]:[minute]:[second] GMT",
471    )
472    .unwrap();
473    OffsetDateTime::from_unix_timestamp(date as i64)
474        .unwrap()
475        .format(&format)
476        .unwrap()
477}
478
479#[cfg(test)]
480mod test;