1use std::{borrow::Cow, convert::Infallible, future::Future, pin::Pin, sync::Arc, task::Poll};
46
47use axum_core::body::Body;
48use axum_core::extract::Request;
49use axum_core::response::Response;
50use chrono::{DateTime, Utc};
51use http::StatusCode;
52use rust_embed::RustEmbed;
53use tower_service::Service;
54
/// Assets embedded from the `src/assets` folder.
///
/// This is where the built-in `404.html` page comes from; it is served when
/// neither the requested file nor a configured fallback file can be found.
#[derive(Clone, RustEmbed)]
#[folder = "src/assets"]
struct DefaultFallback;
58
/// How a request that matched no embedded file is answered when a fallback
/// file is configured.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FallbackBehavior {
    /// Serve the fallback file's contents with a `404 Not Found` status.
    NotFound,
    /// Answer with a `307 Temporary Redirect` whose `Location` is the fallback file.
    Redirect,
    /// Serve the fallback file's contents with a `200 OK` status.
    Ok,
}
69
70#[derive(Debug, Clone)]
96pub struct ServeEmbed<E: RustEmbed + Clone> {
97 _phantom: std::marker::PhantomData<E>,
98 fallback_file: Arc<Option<String>>,
99 fallback_behavior: FallbackBehavior,
100 index_file: Arc<Option<String>>,
101}
102
103impl<E: RustEmbed + Clone> ServeEmbed<E> {
104 pub fn new() -> Self {
111 Self::with_parameters(
112 None,
113 FallbackBehavior::NotFound,
114 Some("index.html".to_owned()),
115 )
116 }
117
118 pub fn with_parameters(
128 fallback_file: Option<String>,
129 fallback_behavior: FallbackBehavior,
130 index_file: Option<String>,
131 ) -> Self {
132 Self {
133 _phantom: std::marker::PhantomData,
134 fallback_file: Arc::new(fallback_file),
135 fallback_behavior,
136 index_file: Arc::new(index_file),
137 }
138 }
139}
140
141impl<E: RustEmbed + Clone, T: Send + 'static> Service<http::request::Request<T>> for ServeEmbed<E> {
142 type Response = Response;
143 type Error = Infallible;
144 type Future = ServeFuture<E, T>;
145
146 fn poll_ready(
147 &mut self,
148 _cx: &mut std::task::Context<'_>,
149 ) -> std::task::Poll<Result<(), Self::Error>> {
150 Poll::Ready(Ok(()))
151 }
152
153 fn call(&mut self, req: http::request::Request<T>) -> Self::Future {
154 ServeFuture {
155 _phantom: std::marker::PhantomData,
156 fallback_behavior: self.fallback_behavior,
157 fallback_file: self.fallback_file.clone(),
158 index_file: self.index_file.clone(),
159 request: req,
160 }
161 }
162}
163
/// Content codings for which a pre-compressed variant of a file may be embedded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum CompressionMethod {
    /// The file as-is, without any compression.
    Identity,
    /// Brotli-compressed variant.
    Brotli,
    /// Gzip-compressed variant.
    Gzip,
    /// Zlib-compressed variant (announced to clients as `deflate`).
    Zlib,
}

impl CompressionMethod {
    /// Suffix appended to a file path to find the variant stored with this
    /// compression; empty for [`CompressionMethod::Identity`].
    fn extension(self) -> &'static str {
        match self {
            CompressionMethod::Brotli => ".br",
            CompressionMethod::Gzip => ".gz",
            CompressionMethod::Zlib => ".zz",
            CompressionMethod::Identity => "",
        }
    }
}
182
183fn from_acceptable_encoding(acceptable_encoding: Option<&str>) -> Vec<CompressionMethod> {
184 let mut compression_methods = Vec::new();
185
186 let mut identity_found = false;
187 for acceptable_encoding in acceptable_encoding.unwrap_or("").split(',') {
188 let acceptable_encoding = acceptable_encoding.trim().split(';').next().unwrap();
189 if acceptable_encoding == "br" {
190 compression_methods.push(CompressionMethod::Brotli);
191 } else if acceptable_encoding == "gzip" {
192 compression_methods.push(CompressionMethod::Gzip);
193 } else if acceptable_encoding == "deflate" {
194 compression_methods.push(CompressionMethod::Zlib);
195 } else if acceptable_encoding == "identity" {
196 compression_methods.push(CompressionMethod::Identity);
197 identity_found = true;
198 }
199 }
200
201 if !identity_found {
202 compression_methods.push(CompressionMethod::Identity);
203 }
204
205 compression_methods
206}
207
/// Outcome of resolving a request path against the embedded files.
struct GetFileResult<'a> {
    /// Path the lookup settled on (leading slashes removed, index file name
    /// appended where applicable); used to guess the content type.
    path: Cow<'a, str>,
    /// The embedded file to serve, if one was found.
    file: Option<rust_embed::EmbeddedFile>,
    /// When set, the `Location` the client should be redirected to.
    should_redirect: Option<String>,
    /// Compression of the variant held in `file`.
    compression_method: CompressionMethod,
    /// `true` when the result was produced by fallback handling rather than
    /// by the requested path itself.
    is_fallback: bool,
}
215
/// Future returned by `ServeEmbed`'s `Service::call`; it carries the request
/// together with a copy of the service configuration.
#[derive(Debug, Clone)]
pub struct ServeFuture<E: RustEmbed, T> {
    /// Ties the future to the embedded asset type without storing a value of it.
    _phantom: std::marker::PhantomData<E>,
    /// How the fallback file is delivered (status code or redirect).
    fallback_behavior: FallbackBehavior,
    /// Embedded file used when the requested path matches nothing.
    fallback_file: Arc<Option<String>>,
    /// File name looked up for directory-like paths.
    index_file: Arc<Option<String>>,
    /// The request being answered.
    request: Request<T>,
}
227
228impl<E: RustEmbed, T> ServeFuture<E, T> {
229 fn get_file<'a>(
238 &self,
239 path: &'a str,
240 acceptable_encoding: &[CompressionMethod],
241 ) -> GetFileResult<'a> {
242 let mut path_candidate = Cow::Borrowed(path.trim_start_matches('/'));
243
244 if path_candidate == "" {
245 if let Some(index_file) = self.index_file.as_ref() {
246 path_candidate = Cow::Owned(index_file.to_string());
247 }
248 } else if path_candidate.ends_with('/') {
249 if let Some(index_file) = self.index_file.as_ref().as_ref() {
250 let new_path_candidate = format!("{}{}", path_candidate, index_file);
251 if E::get(&new_path_candidate).is_some() {
252 path_candidate = Cow::Owned(new_path_candidate);
253 }
254 }
255 } else {
256 if let Some(index_file) = self.index_file.as_ref().as_ref() {
257 let new_path_candidate = format!("{}/{}", path_candidate, index_file);
258 if E::get(&new_path_candidate).is_some() {
259 return GetFileResult {
260 path: Cow::Owned(new_path_candidate),
261 file: None,
262 should_redirect: Some(format!("/{}/", path_candidate)),
263 compression_method: CompressionMethod::Identity,
264 is_fallback: false,
265 };
266 }
267 }
268 }
269
270 let mut file = E::get(&path_candidate);
271 let mut compressed_method = CompressionMethod::Identity;
272
273 if file.is_some() {
274 for one_method in acceptable_encoding {
275 if let Some(x) = E::get(&format!("{}{}", path_candidate, one_method.extension())) {
276 file = Some(x);
277 compressed_method = *one_method;
278 break;
279 }
280 }
281 }
282
283 GetFileResult {
284 path: path_candidate,
285 file,
286 should_redirect: None,
287 compression_method: compressed_method,
288 is_fallback: false,
289 }
290 }
291
292 fn get_file_with_fallback<'a, 'b: 'a>(
293 &'b self,
294 path: &'a str,
295 acceptable_encoding: &[CompressionMethod],
296 ) -> GetFileResult<'a> {
297 let first_try = self.get_file(path, acceptable_encoding);
298 if first_try.file.is_some() || first_try.should_redirect.is_some() {
299 return first_try;
300 }
301 if let Some(fallback_file) = self.fallback_file.as_ref().as_ref() {
302 if fallback_file != path && self.fallback_behavior == FallbackBehavior::Redirect {
303 return GetFileResult {
304 path: Cow::Borrowed(path),
305 file: None,
306 should_redirect: Some(format!("/{}", fallback_file)),
307 compression_method: CompressionMethod::Identity,
308 is_fallback: true,
309 };
310 }
311 let mut fallback_try = self.get_file(fallback_file, acceptable_encoding);
312 fallback_try.is_fallback = true;
313 if fallback_try.file.is_some() {
314 return fallback_try;
315 }
316 }
317 GetFileResult {
318 path: Cow::Borrowed("404.html"),
319 file: DefaultFallback::get("404.html"),
320 should_redirect: None,
321 compression_method: CompressionMethod::Identity,
322 is_fallback: true,
323 }
324 }
325}
326
327impl<E: RustEmbed, T> Future for ServeFuture<E, T> {
328 type Output = Result<Response<Body>, Infallible>;
329
330 fn poll(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
331 if self.request.method() != http::Method::GET && self.request.method() != http::Method::HEAD
333 {
334 return Poll::Ready(Ok(Response::builder()
335 .status(StatusCode::METHOD_NOT_ALLOWED)
336 .header(http::header::CONTENT_TYPE, "text/plain")
337 .body(Body::from("Method not allowed"))
338 .unwrap()));
339 }
340
341 let (path, file, compression_method, is_fallback) = match self.get_file_with_fallback(
343 self.request.uri().path(),
344 &from_acceptable_encoding(
345 self.request
346 .headers()
347 .get(http::header::ACCEPT_ENCODING)
348 .map(|x| x.to_str().ok())
349 .flatten(),
350 ),
351 ) {
352 GetFileResult {
354 path,
355 file: Some(file),
356 should_redirect: None,
357 compression_method,
358 is_fallback,
359 } => (path, file, compression_method, is_fallback),
360 GetFileResult {
362 path: _,
363 file: _,
364 should_redirect: Some(should_redirect),
365 compression_method: _,
366 is_fallback,
367 } => {
368 return Poll::Ready(Ok(Response::builder()
369 .status(if is_fallback {
370 StatusCode::TEMPORARY_REDIRECT
371 } else {
372 StatusCode::MOVED_PERMANENTLY
373 })
374 .header(http::header::LOCATION, should_redirect)
375 .header(http::header::CONTENT_TYPE, "text/plain")
376 .body(if is_fallback {
377 Body::from("Temporary redirect")
378 } else {
379 Body::from("Moved permanently")
380 })
381 .unwrap()));
382 }
383 _ => {
385 unreachable!();
386 }
387 };
388
389 if !is_fallback
391 && self
392 .request
393 .headers()
394 .get(http::header::IF_NONE_MATCH)
395 .and_then(|value| {
396 value
397 .to_str()
398 .ok()
399 .and_then(|value| Some(value.trim_matches('"')))
400 })
401 == Some(hash_to_string(&file.metadata.sha256_hash()).as_str())
402 {
403 return Poll::Ready(Ok(Response::builder()
404 .status(StatusCode::NOT_MODIFIED)
405 .body(Body::empty())
406 .unwrap()));
407 }
408
409 let mut response_builder = Response::builder()
411 .header(
412 http::header::CONTENT_TYPE,
413 mime_guess::from_path(path.as_ref())
414 .first_or_octet_stream()
415 .to_string(),
416 )
417 .header(
418 http::header::ETAG,
419 hash_to_string(&file.metadata.sha256_hash()),
420 );
421
422 match compression_method {
423 CompressionMethod::Identity => {}
424 CompressionMethod::Brotli => {
425 response_builder = response_builder.header(http::header::CONTENT_ENCODING, "br");
426 }
427 CompressionMethod::Gzip => {
428 response_builder = response_builder.header(http::header::CONTENT_ENCODING, "gzip");
429 }
430 CompressionMethod::Zlib => {
431 response_builder =
432 response_builder.header(http::header::CONTENT_ENCODING, "deflate");
433 }
434 }
435
436 if let Some(last_modified) = file.metadata.last_modified() {
437 response_builder =
438 response_builder.header(http::header::LAST_MODIFIED, date_to_string(last_modified));
439 }
440
441 if is_fallback && self.fallback_behavior != FallbackBehavior::Ok {
442 response_builder = response_builder.status(StatusCode::NOT_FOUND);
443 } else {
444 response_builder = response_builder.status(StatusCode::OK);
445 }
446
447 Poll::Ready(Ok(response_builder
448 .body(file.data.to_owned().into())
449 .unwrap()))
450 }
451}
452
/// Renders a 32-byte hash as 64 lowercase hexadecimal characters.
fn hash_to_string(hash: &[u8; 32]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut hex = String::with_capacity(hash.len() * 2);
    for &byte in hash {
        // High nibble first, then low nibble.
        hex.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
460
461fn date_to_string(date: u64) -> String {
462 DateTime::<Utc>::from_timestamp(date as i64, 0)
463 .unwrap()
464 .format("%a, %d %b %Y %H:%M:%S GMT")
465 .to_string()
466}
467
// Unit tests live in the `test` submodule, which is compiled only for test builds.
#[cfg(test)]
mod test;