salvo_serve_static/
embed.rs

1use std::borrow::Cow;
2use std::fmt::{self, Debug, Formatter};
3use std::marker::PhantomData;
4
5use rust_embed::{EmbeddedFile, Metadata, RustEmbed};
6use salvo_core::handler::Handler;
7use salvo_core::http::header::{
8    ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_TYPE, ETAG, IF_NONE_MATCH, RANGE,
9};
10use salvo_core::http::headers::{ContentLength, ContentRange, HeaderMapExt};
11use salvo_core::http::mime::fill_mime_charset_if_need;
12use salvo_core::http::{HeaderValue, HttpRange, Mime, Request, Response, StatusCode};
13use salvo_core::{Depot, FlowCtrl, IntoVecString, async_trait};
14
15use super::{decode_url_path_safely, format_url_path_safely, redirect_to_dir_url};
16
/// Handler that serves files embedded in the application binary using `rust-embed`.
///
/// Embedding assets lets the application be distributed as one self-contained
/// executable. `T` is the `RustEmbed` asset type; it is only held as `PhantomData<T>`,
/// so neither `Default` nor `Debug` for this handler requires anything of `T`.
#[non_exhaustive]
pub struct StaticEmbed<T> {
    _assets: PhantomData<T>,
    /// File names (e.g. `index.html`) tried, in order, beneath the request path when
    /// the path itself is not an embedded file.
    pub defaults: Vec<String>,
    /// File served when neither the request path nor any default file exists.
    pub fallback: Option<String>,
}

// Written by hand rather than derived: `#[derive(Default)]` would add a `T: Default`
// bound even though only a `PhantomData<T>` is stored.
impl<T> Default for StaticEmbed<T> {
    fn default() -> Self {
        Self {
            _assets: PhantomData,
            defaults: Vec::new(),
            fallback: None,
        }
    }
}

// No `T: Debug` bound: `T` is never printed, only the configuration fields are.
impl<T> Debug for StaticEmbed<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("StaticEmbed")
            .field("defaults", &self.defaults)
            .field("fallback", &self.fallback)
            .finish()
    }
}
38
39/// Create a new `StaticEmbed` handler for the given embedded asset type.
40#[inline]
41#[must_use]
42pub fn static_embed<T: RustEmbed>() -> StaticEmbed<T> {
43    StaticEmbed {
44        _assets: PhantomData,
45        defaults: vec![],
46        fallback: None,
47    }
48}
49
50/// Render an [`EmbeddedFile`] to the [`Response`].
51#[inline]
52pub fn render_embedded_file(
53    file: EmbeddedFile,
54    req: &Request,
55    res: &mut Response,
56    mime: Option<Mime>,
57) {
58    let EmbeddedFile { data, metadata } = file;
59    render_embedded_data(data, &metadata, req, res, mime);
60}
61
62fn render_embedded_data(
63    data: Cow<'static, [u8]>,
64    metadata: &Metadata,
65    req: &Request,
66    res: &mut Response,
67    mime: Option<Mime>,
68) {
69    // Determine Content-Type once
70    let content_type =
71        if let Some(mut mime) = mime.or_else(|| mime_infer::from_path(req.uri().path()).first()) {
72            fill_mime_charset_if_need(&mut mime, &data);
73            mime
74        } else {
75            mime::APPLICATION_OCTET_STREAM
76        };
77
78    res.headers_mut().insert(
79        CONTENT_TYPE,
80        content_type
81            .as_ref()
82            .parse()
83            .unwrap_or_else(|_| HeaderValue::from_static("application/octet-stream")),
84    );
85
86    // ETag generation and If-None-Match check
87    let hash = hex::encode(metadata.sha256_hash());
88    if req
89        .headers()
90        .get(IF_NONE_MATCH)
91        .map(|etag| etag.to_str().unwrap_or("000000").eq(&hash))
92        .unwrap_or(false)
93    {
94        res.status_code(StatusCode::NOT_MODIFIED);
95        return;
96    }
97
98    // Set ETag for all successful responses (200 or 206)
99    if let Ok(etag_val) = hash.parse() {
100        res.headers_mut().insert(ETAG, etag_val);
101    } else {
102        tracing::error!("Failed to parse etag hash: {}", hash);
103    }
104
105    // Indicate that byte ranges are accepted
106    res.headers_mut()
107        .insert(ACCEPT_RANGES, HeaderValue::from_static("bytes"));
108
109    let total_data_len = data.len() as u64;
110    let mut is_partial_content = false;
111    let mut range_to_send: Option<(u64, u64)> = None; // (start_offset, length_of_part)
112
113    let req_headers = req.headers();
114    if let Some(range_header_val) = req_headers.get(RANGE) {
115        if let Ok(range_str) = range_header_val.to_str() {
116            match HttpRange::parse(range_str, total_data_len) {
117                Ok(ranges) if !ranges.is_empty() => {
118                    // Successfully parsed and satisfiable range(s). We only handle the first one.
119                    let first_range = &ranges[0]; // HttpRange ensures start + length <= total_data_len
120                    is_partial_content = true;
121                    range_to_send = Some((first_range.start, first_range.length));
122
123                    res.status_code(StatusCode::PARTIAL_CONTENT);
124                    match ContentRange::bytes(
125                        first_range.start..(first_range.start + first_range.length),
126                        total_data_len,
127                    ) {
128                        Ok(content_range_header) => {
129                            res.headers_mut().typed_insert(content_range_header);
130                        }
131                        Err(e) => {
132                            tracing::error!(error = ?e, "Failed to create Content-Range header");
133                            res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
134                            return;
135                        }
136                    }
137                }
138                Err(_) => {
139                    // HttpRange::parse returns Err if the range is unsatisfiable or malformed.
140                    res.headers_mut()
141                        .typed_insert(ContentRange::unsatisfied_bytes(total_data_len));
142                    res.status_code(StatusCode::RANGE_NOT_SATISFIABLE);
143                    return;
144                }
145                Ok(_) => {
146                    // Parsed, but no valid ranges. Treat as full content.
147                    // is_partial_content remains false.
148                }
149            }
150        } else {
151            // Failed to convert Range header to string (e.g., invalid UTF-8)
152            res.status_code(StatusCode::BAD_REQUEST);
153            return;
154        }
155    }
156
157    if is_partial_content {
158        if let Some((offset, length)) = range_to_send {
159            // Ensure the range is valid before slicing. HttpRange::parse should guarantee this.
160            let end_offset = offset
161                .checked_add(length)
162                .expect("Range calculation overflowed");
163            if end_offset <= total_data_len {
164                // Check to prevent panic on slice
165                let partial_data_vec = data[offset as usize..end_offset as usize].to_vec();
166                res.headers_mut().typed_insert(ContentLength(length));
167                let _ = res.write_body(partial_data_vec); // write_body can take Vec<u8>
168            } else {
169                // This should ideally be caught by HttpRange::parse or ContentRange::bytes
170                tracing::error!("Calculated range exceeds data bounds after HttpRange::parse");
171                res.headers_mut()
172                    .typed_insert(ContentRange::unsatisfied_bytes(total_data_len));
173                res.status_code(StatusCode::RANGE_NOT_SATISFIABLE);
174                // Clear content length if we are not sending a body for this error
175                res.headers_mut().remove(CONTENT_LENGTH);
176            }
177        } else {
178            // Should not happen if is_partial_content is true.
179            tracing::error!("is_partial_content is true but range_to_send is None");
180            res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
181        }
182    } else {
183        // Serve full content
184        res.status_code(StatusCode::OK); // Ensure OK status
185        res.headers_mut()
186            .typed_insert(ContentLength(total_data_len));
187        match data {
188            Cow::Borrowed(d) => {
189                let _ = res.write_body(d);
190            }
191            Cow::Owned(o) => {
192                let _ = res.write_body(o);
193            }
194        }
195    }
196}
197
198impl<T> StaticEmbed<T>
199where
200    T: RustEmbed + Send + Sync + 'static,
201{
202    /// Create a new `StaticEmbed`.
203    #[inline]
204    #[must_use]
205    pub fn new() -> Self {
206        Self {
207            _assets: PhantomData,
208            defaults: vec![],
209            fallback: None,
210        }
211    }
212
213    /// Create a new `StaticEmbed` with defaults.
214    #[inline]
215    #[must_use]
216    pub fn defaults(mut self, defaults: impl IntoVecString) -> Self {
217        self.defaults = defaults.into_vec_string();
218        self
219    }
220
221    /// Create a new `StaticEmbed` with fallback.
222    #[inline]
223    #[must_use]
224    pub fn fallback(mut self, fallback: impl Into<String>) -> Self {
225        self.fallback = Some(fallback.into());
226        self
227    }
228}
229#[async_trait]
230impl<T> Handler for StaticEmbed<T>
231where
232    T: RustEmbed + Send + Sync + 'static,
233{
234    async fn handle(
235        &self,
236        req: &mut Request,
237        _depot: &mut Depot,
238        res: &mut Response,
239        _ctrl: &mut FlowCtrl,
240    ) {
241        let req_path = if let Some(rest) = req.params().tail() {
242            rest
243        } else {
244            &*decode_url_path_safely(req.uri().path())
245        };
246        let req_path = format_url_path_safely(req_path);
247        let mut key_path = Cow::Borrowed(&*req_path);
248        let mut embedded_file = T::get(req_path.as_str());
249        if embedded_file.is_none() {
250            for ifile in &self.defaults {
251                let ipath = join_path!(&req_path, ifile);
252                if let Some(file) = T::get(&ipath) {
253                    embedded_file = Some(file);
254                    key_path = Cow::from(ipath);
255                    break;
256                }
257            }
258            if embedded_file.is_some() && !req_path.ends_with('/') && !req_path.is_empty() {
259                redirect_to_dir_url(req.uri(), res);
260                return;
261            }
262        }
263        if embedded_file.is_none() {
264            let fallback = self.fallback.as_deref().unwrap_or_default();
265            if !fallback.is_empty() {
266                if let Some(file) = T::get(fallback) {
267                    embedded_file = Some(file);
268                    key_path = Cow::from(fallback);
269                }
270            }
271        }
272
273        match embedded_file {
274            Some(file) => {
275                let mime = mime_infer::from_path(&*key_path).first();
276                render_embedded_file(file, req, res, mime);
277            }
278            None => {
279                res.status_code(StatusCode::NOT_FOUND);
280            }
281        }
282    }
283}
284
285/// Handler for [`EmbeddedFile`].
286pub struct EmbeddedFileHandler(pub EmbeddedFile);
287impl Debug for EmbeddedFileHandler {
288    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
289        f.debug_struct("EmbeddedFileHandler").finish()
290    }
291}
292
293#[async_trait]
294impl Handler for EmbeddedFileHandler {
295    #[inline]
296    async fn handle(
297        &self,
298        req: &mut Request,
299        _depot: &mut Depot,
300        res: &mut Response,
301        _ctrl: &mut FlowCtrl,
302    ) {
303        render_embedded_data(self.0.data.clone(), &self.0.metadata, req, res, None);
304    }
305}
306
307/// Extension trait for [`EmbeddedFile`].
308pub trait EmbeddedFileExt {
309    /// Render the embedded file.
310    fn render(self, req: &Request, res: &mut Response);
311    /// Create a handler for the embedded file.
312    fn into_handler(self) -> EmbeddedFileHandler;
313}
314
315impl EmbeddedFileExt for EmbeddedFile {
316    #[inline]
317    fn render(self, req: &Request, res: &mut Response) {
318        render_embedded_file(self, req, res, None);
319    }
320    #[inline]
321    fn into_handler(self) -> EmbeddedFileHandler {
322        EmbeddedFileHandler(self)
323    }
324}