1use std::{
2 ffi::OsStr,
3 path::{Path, PathBuf},
4};
5
6use poem::{
7 error::StaticFileError, http::Method, web::StaticFileRequest, Endpoint, FromRequest,
8 IntoResponse, Request, Response, Result,
9};
10
// Unit tests for this endpoint live in the `test` submodule and are compiled
// only for test builds.
#[cfg(test)]
mod test;
13
/// Endpoint that serves a single-page application from a directory.
///
/// Existing regular files under `base` are served directly. Requests under one of
/// the `assets` prefixes that do not hit a regular file produce an error. Every
/// other request path falls back to the `index` page, so client-side routes work.
#[derive(Debug)]
pub struct SPAEndpoint {
    /// Root directory files are served from; resolved request paths must stay under it.
    base: PathBuf,
    /// Fallback page served for non-file, non-asset paths (already joined onto `base`).
    index: PathBuf,
    /// Asset path prefixes (each already joined onto `base`).
    assets: Vec<PathBuf>,
}
20
21impl SPAEndpoint {
22 pub fn new(base: impl Into<PathBuf>, index: impl Into<PathBuf>) -> Self {
23 let base_path = base.into();
24 Self {
25 index: base_path.join(index.into()),
26 base: base_path,
27 assets: Vec::new(),
28 }
29 }
30
31 #[must_use]
32 pub fn with_assets(self, asset: impl Into<PathBuf>) -> Self {
33 Self {
34 assets: [self.assets, vec![self.base.join(asset.into())]].concat(),
35 ..self
36 }
37 }
38
39 fn path_allowed(&self, path: &Path) -> bool {
40 path.starts_with(&self.base)
41 }
42
43 fn is_asset(&self, path: &Path) -> bool {
44 self.assets.iter().any(|asset| path.starts_with(asset))
45 }
46}
47
48async fn serve_static(req: &Request, file_path: &PathBuf) -> Result<Response> {
49 Ok(StaticFileRequest::from_request_without_body(req)
50 .await?
51 .create_response(file_path, true)?
52 .into_response())
53}
54
55#[async_trait::async_trait]
56impl Endpoint for SPAEndpoint {
57 type Output = Response;
58
59 async fn call(&self, req: Request) -> Result<Self::Output> {
60 if req.method() != Method::GET {
61 return Err(StaticFileError::MethodNotAllowed(req.method().clone()).into());
62 }
63
64 let path = req
65 .uri()
66 .path()
67 .trim_start_matches('/')
68 .trim_end_matches('/');
69
70 let path = percent_encoding::percent_decode_str(path)
71 .decode_utf8()
72 .map_err(|_| StaticFileError::InvalidPath)?;
73
74 let mut file_path = self.base.clone();
75 for p in Path::new(&*path) {
76 if p == OsStr::new(".") {
77 continue;
78 } else if p == OsStr::new("..") {
79 file_path.pop();
80 } else {
81 file_path.push(p);
82 }
83 }
84
85 if !self.path_allowed(&file_path) {
86 return Err(StaticFileError::Forbidden(file_path.display().to_string()).into());
87 }
88
89 if file_path.exists() && file_path.is_file() {
90 serve_static(&req, &file_path).await
91 } else if self.is_asset(&file_path) {
92 if file_path.exists() {
93 return Err(StaticFileError::Forbidden(file_path.display().to_string()).into());
94 } else {
95 Err(StaticFileError::NotFound.into())
96 }
97 } else {
98 serve_static(&req, &self.index).await
99 }
100 }
101}