1use std::{borrow::Cow, convert::Infallible, future::Future, pin::Pin, sync::Arc, task::Poll};
49
50use bytes::Bytes;
51use http::{Request, Response, StatusCode};
52use http_body_util::Full;
53use rust_embed::RustEmbed;
54use time::{OffsetDateTime, format_description};
55use tower_service::Service;
56
57#[derive(Clone, RustEmbed)]
58#[folder = "src/assets"]
59struct DefaultFallback;
60
/// How a request that matches no embedded asset is answered.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FallbackBehavior {
    /// Serve the fallback file with status `404 Not Found`. If no fallback file
    /// is configured, or it cannot be found, the built-in 404 page is served
    /// instead.
    NotFound,
    /// Answer with `307 Temporary Redirect` to `/<fallback_file>`. Without a
    /// configured fallback file, the built-in 404 page is served with status 404.
    Redirect,
    /// Serve the fallback content with status `200 OK`. This also applies to the
    /// built-in 404 page when the fallback file is unset or missing.
    Ok,
}
71
72#[derive(Debug, Clone)]
98pub struct ServeEmbed<E: RustEmbed + Clone> {
99 _phantom: std::marker::PhantomData<E>,
100 fallback_file: Arc<Option<String>>,
101 fallback_behavior: FallbackBehavior,
102 index_file: Arc<Option<String>>,
103}
104
105impl<E: RustEmbed + Clone> ServeEmbed<E> {
106 pub fn new() -> Self {
113 Self::with_parameters(
114 None,
115 FallbackBehavior::NotFound,
116 Some("index.html".to_owned()),
117 )
118 }
119
120 pub fn with_parameters(
130 fallback_file: Option<String>,
131 fallback_behavior: FallbackBehavior,
132 index_file: Option<String>,
133 ) -> Self {
134 Self {
135 _phantom: std::marker::PhantomData,
136 fallback_file: Arc::new(fallback_file),
137 fallback_behavior,
138 index_file: Arc::new(index_file),
139 }
140 }
141}
142
143impl<E: RustEmbed + Clone> Default for ServeEmbed<E> {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
149impl<E: RustEmbed + Clone, T: Send + 'static> Service<http::request::Request<T>> for ServeEmbed<E> {
150 type Response = http::Response<Full<Bytes>>;
151 type Error = Infallible;
152 type Future = ServeFuture<E, T>;
153
154 fn poll_ready(
155 &mut self,
156 _cx: &mut std::task::Context<'_>,
157 ) -> std::task::Poll<Result<(), Self::Error>> {
158 Poll::Ready(Ok(()))
159 }
160
161 fn call(&mut self, req: http::request::Request<T>) -> Self::Future {
162 ServeFuture {
163 _phantom: std::marker::PhantomData,
164 fallback_behavior: self.fallback_behavior,
165 fallback_file: self.fallback_file.clone(),
166 index_file: self.index_file.clone(),
167 request: req,
168 }
169 }
170}
171
/// Content codings this service can serve.
///
/// Each non-identity coding is served from a precompressed sibling of the
/// requested asset, located via [`CompressionMethod::extension`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum CompressionMethod {
    /// The asset itself, without any content coding.
    Identity,
    /// Brotli (`Content-Encoding: br`), read from `<asset>.br`.
    Brotli,
    /// Gzip (`Content-Encoding: gzip`), read from `<asset>.gz`.
    Gzip,
    /// Deflate/zlib (`Content-Encoding: deflate`), read from `<asset>.zz`.
    Zlib,
}

impl CompressionMethod {
    /// Suffix appended to an asset path to locate its variant in this coding.
    const fn extension(self) -> &'static str {
        match self {
            Self::Identity => "",
            Self::Brotli => ".br",
            Self::Gzip => ".gz",
            Self::Zlib => ".zz",
        }
    }

    /// Maps an `Accept-Encoding` coding name to a method.
    ///
    /// Coding names are case-insensitive (RFC 9110 §8.4.1), and `x-gzip` is
    /// treated as `gzip` (RFC 9110 §8.4.1.3). Unsupported codings, including
    /// the `*` wildcard, yield `None`.
    fn from_coding(coding: &str) -> Option<Self> {
        let is = |name: &str| coding.eq_ignore_ascii_case(name);
        if is("br") {
            Some(Self::Brotli)
        } else if is("gzip") || is("x-gzip") {
            Some(Self::Gzip)
        } else if is("deflate") {
            Some(Self::Zlib)
        } else if is("identity") {
            Some(Self::Identity)
        } else {
            None
        }
    }
}

/// Returns the `q` weight found among the parameters of one `Accept-Encoding`
/// entry.
///
/// A missing or unparsable weight counts as `1.0`, the RFC default.
fn accept_encoding_quality<'p>(params: impl Iterator<Item = &'p str>) -> f32 {
    params
        .filter_map(|param| {
            let (name, value) = param.split_once('=')?;
            if name.trim().eq_ignore_ascii_case("q") {
                value.trim().parse::<f32>().ok()
            } else {
                None
            }
        })
        .next()
        .unwrap_or(1.0)
}

/// Converts an `Accept-Encoding` header value into the codings to try, most
/// preferred first.
///
/// * Entries are ordered by `q` weight, highest first. Equal weights keep the
///   order the client sent.
/// * Entries with `q=0` mean "not acceptable" (RFC 9110 §12.4.2) and are
///   dropped.
/// * Unknown codings are ignored.
/// * [`CompressionMethod::Identity`] is appended as a last resort unless the
///   client listed it with a non-zero weight. This way the plain asset can
///   always be served.
fn from_acceptable_encoding(acceptable_encoding: Option<&str>) -> Vec<CompressionMethod> {
    let mut weighted: Vec<(CompressionMethod, f32)> = acceptable_encoding
        .unwrap_or("")
        .split(',')
        .filter_map(|entry| {
            let mut parts = entry.split(';');
            let method = CompressionMethod::from_coding(parts.next()?.trim())?;
            let quality = accept_encoding_quality(parts);
            // `quality > 0.0` is also false for NaN, which `f32::from_str` accepts.
            (quality > 0.0).then(|| (method, quality))
        })
        .collect();

    // `sort_by` is stable, so ties keep the client's order.
    weighted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    let mut compression_methods: Vec<CompressionMethod> =
        weighted.into_iter().map(|(method, _)| method).collect();
    if !compression_methods.contains(&CompressionMethod::Identity) {
        compression_methods.push(CompressionMethod::Identity);
    }
    compression_methods
}
215
216fn cow_to_bytes(cow: Cow<'static, [u8]>) -> Bytes {
217 match cow {
218 Cow::Borrowed(x) => Bytes::from(x),
219 Cow::Owned(x) => Bytes::from(x),
220 }
221}
222
223struct GetFileResult<'a> {
224 path: Cow<'a, str>,
225 file: Option<rust_embed::EmbeddedFile>,
226 should_redirect: Option<String>,
227 compression_method: CompressionMethod,
228 is_fallback: bool,
229}
230
231#[derive(Debug, Clone)]
235pub struct ServeFuture<E: RustEmbed, T> {
236 _phantom: std::marker::PhantomData<E>,
237 fallback_behavior: FallbackBehavior,
238 fallback_file: Arc<Option<String>>,
239 index_file: Arc<Option<String>>,
240 request: Request<T>,
241}
242
243impl<E: RustEmbed, T> ServeFuture<E, T> {
244 fn get_file<'a>(
253 &self,
254 path: &'a str,
255 acceptable_encoding: &[CompressionMethod],
256 ) -> GetFileResult<'a> {
257 let mut path_candidate = Cow::Borrowed(path.trim_start_matches('/'));
258
259 if path_candidate.is_empty() {
260 if let Some(index_file) = self.index_file.as_ref() {
261 path_candidate = Cow::Owned(index_file.to_string());
262 }
263 } else if path_candidate.ends_with('/') {
264 if let Some(index_file) = self.index_file.as_ref().as_ref() {
265 let new_path_candidate = format!("{}{}", path_candidate, index_file);
266 if E::get(&new_path_candidate).is_some() {
267 path_candidate = Cow::Owned(new_path_candidate);
268 }
269 }
270 } else if let Some(index_file) = self.index_file.as_ref().as_ref() {
271 let new_path_candidate = format!("{}/{}", path_candidate, index_file);
272 if E::get(&new_path_candidate).is_some() {
273 return GetFileResult {
274 path: Cow::Owned(new_path_candidate),
275 file: None,
276 should_redirect: Some(format!("/{}/", path_candidate)),
277 compression_method: CompressionMethod::Identity,
278 is_fallback: false,
279 };
280 }
281 }
282
283 let mut file = E::get(&path_candidate);
284 let mut compressed_method = CompressionMethod::Identity;
285
286 if file.is_some() {
287 for one_method in acceptable_encoding {
288 if let Some(x) = E::get(&format!("{}{}", path_candidate, one_method.extension())) {
289 file = Some(x);
290 compressed_method = *one_method;
291 break;
292 }
293 }
294 }
295
296 GetFileResult {
297 path: path_candidate,
298 file,
299 should_redirect: None,
300 compression_method: compressed_method,
301 is_fallback: false,
302 }
303 }
304
305 fn get_file_with_fallback<'a, 'b: 'a>(
306 &'b self,
307 path: &'a str,
308 acceptable_encoding: &[CompressionMethod],
309 ) -> GetFileResult<'a> {
310 let first_try = self.get_file(path, acceptable_encoding);
311 if first_try.file.is_some() || first_try.should_redirect.is_some() {
312 return first_try;
313 }
314 if let Some(fallback_file) = self.fallback_file.as_ref().as_ref() {
315 if fallback_file != path && self.fallback_behavior == FallbackBehavior::Redirect {
316 return GetFileResult {
317 path: Cow::Borrowed(path),
318 file: None,
319 should_redirect: Some(format!("/{}", fallback_file)),
320 compression_method: CompressionMethod::Identity,
321 is_fallback: true,
322 };
323 }
324 let mut fallback_try = self.get_file(fallback_file, acceptable_encoding);
325 fallback_try.is_fallback = true;
326 if fallback_try.file.is_some() {
327 return fallback_try;
328 }
329 }
330 GetFileResult {
331 path: Cow::Borrowed("404.html"),
332 file: DefaultFallback::get("404.html"),
333 should_redirect: None,
334 compression_method: CompressionMethod::Identity,
335 is_fallback: true,
336 }
337 }
338}
339
340impl<E: RustEmbed, T> Future for ServeFuture<E, T> {
341 type Output = Result<Response<Full<Bytes>>, Infallible>;
342
343 fn poll(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
344 if self.request.method() != http::Method::GET && self.request.method() != http::Method::HEAD
346 {
347 return Poll::Ready(Ok(Response::builder()
348 .status(StatusCode::METHOD_NOT_ALLOWED)
349 .header(http::header::CONTENT_TYPE, "text/plain")
350 .body(Full::new(Bytes::from("Method not allowed")))
351 .unwrap()));
352 }
353
354 let (path, file, compression_method, is_fallback) = match self.get_file_with_fallback(
356 self.request.uri().path(),
357 &from_acceptable_encoding(
358 self.request
359 .headers()
360 .get(http::header::ACCEPT_ENCODING)
361 .and_then(|x| x.to_str().ok()),
362 ),
363 ) {
364 GetFileResult {
366 path,
367 file: Some(file),
368 should_redirect: None,
369 compression_method,
370 is_fallback,
371 } => (path, file, compression_method, is_fallback),
372 GetFileResult {
374 path: _,
375 file: _,
376 should_redirect: Some(should_redirect),
377 compression_method: _,
378 is_fallback,
379 } => {
380 return Poll::Ready(Ok(Response::builder()
381 .status(if is_fallback {
382 StatusCode::TEMPORARY_REDIRECT
383 } else {
384 StatusCode::MOVED_PERMANENTLY
385 })
386 .header(http::header::LOCATION, should_redirect)
387 .header(http::header::CONTENT_TYPE, "text/plain")
388 .body(Full::new(if is_fallback {
389 Bytes::from("Temporary redirect")
390 } else {
391 Bytes::from("Moved permanently")
392 }))
393 .unwrap()));
394 }
395 _ => {
397 unreachable!();
398 }
399 };
400
401 if !is_fallback
403 && self
404 .request
405 .headers()
406 .get(http::header::IF_NONE_MATCH)
407 .and_then(|value| value.to_str().ok().map(|value| value.trim_matches('"')))
408 == Some(hash_to_string(&file.metadata.sha256_hash()).as_str())
409 {
410 return Poll::Ready(Ok(Response::builder()
411 .status(StatusCode::NOT_MODIFIED)
412 .body(Full::new(Bytes::from("")))
413 .unwrap()));
414 }
415
416 let mut response_builder = Response::builder()
418 .header(
419 http::header::CONTENT_TYPE,
420 mime_guess::from_path(path.as_ref())
421 .first_or_octet_stream()
422 .to_string(),
423 )
424 .header(
425 http::header::ETAG,
426 hash_to_string(&file.metadata.sha256_hash()),
427 );
428
429 match compression_method {
430 CompressionMethod::Identity => {}
431 CompressionMethod::Brotli => {
432 response_builder = response_builder.header(http::header::CONTENT_ENCODING, "br");
433 }
434 CompressionMethod::Gzip => {
435 response_builder = response_builder.header(http::header::CONTENT_ENCODING, "gzip");
436 }
437 CompressionMethod::Zlib => {
438 response_builder =
439 response_builder.header(http::header::CONTENT_ENCODING, "deflate");
440 }
441 }
442
443 if let Some(last_modified) = file.metadata.last_modified() {
444 response_builder =
445 response_builder.header(http::header::LAST_MODIFIED, date_to_string(last_modified));
446 }
447
448 if is_fallback && self.fallback_behavior != FallbackBehavior::Ok {
449 response_builder = response_builder.status(StatusCode::NOT_FOUND);
450 } else {
451 response_builder = response_builder.status(StatusCode::OK);
452 }
453
454 Poll::Ready(Ok(response_builder
455 .body(Full::new(cow_to_bytes(file.data)))
456 .unwrap()))
457 }
458}
459
/// Renders a 32-byte digest as 64 lowercase hexadecimal characters.
fn hash_to_string(hash: &[u8; 32]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut s = String::with_capacity(hash.len() * 2);
    for &byte in hash {
        // A table lookup avoids the temporary `String` that `format!` built per
        // byte.
        s.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        s.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    s
}
467
468fn date_to_string(date: u64) -> String {
469 let format = format_description::parse(
470 "[weekday repr:short], [day] [month repr:short] [year] [hour]:[minute]:[second] GMT",
471 )
472 .unwrap();
473 OffsetDateTime::from_unix_timestamp(date as i64)
474 .unwrap()
475 .format(&format)
476 .unwrap()
477}
478
// Unit tests live in a separate `test` module file, compiled only for test builds.
#[cfg(test)]
mod test;