poem_spa/
lib.rs

1use std::{
2    ffi::OsStr,
3    path::{Path, PathBuf},
4};
5
6use poem::{
7    error::StaticFileError, http::Method, web::StaticFileRequest, Endpoint, FromRequest,
8    IntoResponse, Request, Response, Result,
9};
10
11#[cfg(test)]
12mod test;
13
/// A [`poem::Endpoint`] that serves a single-page application from a directory.
///
/// A request that resolves to an existing file under `base` is served as-is.
/// A request under one of the registered asset prefixes never falls back to
/// the index page. Every other request inside `base` is answered with the
/// `index` file.
#[derive(Debug)]
pub struct SPAEndpoint {
    /// Root directory; every resolved request path must stay inside it.
    base: PathBuf,
    /// Fallback file, stored already joined onto `base`.
    index: PathBuf,
    /// Asset prefixes, each stored already joined onto `base`.
    assets: Vec<PathBuf>,
}

impl SPAEndpoint {
    /// Creates an endpoint rooted at `base` that falls back to `index`.
    ///
    /// `index` is joined onto `base` with [`Path::join`]. It is therefore
    /// normally relative (e.g. `"index.html"`). An absolute `index` replaces
    /// `base` entirely, per `Path::join` semantics.
    pub fn new(base: impl Into<PathBuf>, index: impl Into<PathBuf>) -> Self {
        let base = base.into();
        let index = base.join(index.into());
        Self {
            base,
            index,
            assets: Vec::new(),
        }
    }

    /// Registers one more asset prefix, interpreted relative to `base`.
    ///
    /// Despite the plural name, each call adds a single prefix. Chain calls to
    /// register several; prefixes are kept in registration order.
    #[must_use]
    pub fn with_assets(mut self, asset: impl Into<PathBuf>) -> Self {
        // Push in place instead of rebuilding the whole Vec on every call.
        let prefix = self.base.join(asset.into());
        self.assets.push(prefix);
        self
    }

    /// Returns `true` if `path` lies inside `base`.
    ///
    /// The check is component-wise ([`Path::starts_with`]) and purely lexical.
    /// Nothing is read from the filesystem here.
    fn path_allowed(&self, path: &Path) -> bool {
        path.starts_with(&self.base)
    }

    /// Returns `true` if `path` lies under any registered asset prefix.
    ///
    /// The comparison is component-wise, so `dist/assetsx` does not match the
    /// prefix `dist/assets`.
    fn is_asset(&self, path: &Path) -> bool {
        self.assets.iter().any(|prefix| path.starts_with(prefix))
    }
}
47
48async fn serve_static(req: &Request, file_path: &PathBuf) -> Result<Response> {
49    Ok(StaticFileRequest::from_request_without_body(req)
50        .await?
51        .create_response(file_path, true)?
52        .into_response())
53}
54
55#[async_trait::async_trait]
56impl Endpoint for SPAEndpoint {
57    type Output = Response;
58
59    async fn call(&self, req: Request) -> Result<Self::Output> {
60        if req.method() != Method::GET {
61            return Err(StaticFileError::MethodNotAllowed(req.method().clone()).into());
62        }
63
64        let path = req
65            .uri()
66            .path()
67            .trim_start_matches('/')
68            .trim_end_matches('/');
69
70        let path = percent_encoding::percent_decode_str(path)
71            .decode_utf8()
72            .map_err(|_| StaticFileError::InvalidPath)?;
73
74        let mut file_path = self.base.clone();
75        for p in Path::new(&*path) {
76            if p == OsStr::new(".") {
77                continue;
78            } else if p == OsStr::new("..") {
79                file_path.pop();
80            } else {
81                file_path.push(p);
82            }
83        }
84
85        if !self.path_allowed(&file_path) {
86            return Err(StaticFileError::Forbidden(file_path.display().to_string()).into());
87        }
88
89        if file_path.exists() && file_path.is_file() {
90            serve_static(&req, &file_path).await
91        } else if self.is_asset(&file_path) {
92            if file_path.exists() {
93                return Err(StaticFileError::Forbidden(file_path.display().to_string()).into());
94            } else {
95                Err(StaticFileError::NotFound.into())
96            }
97        } else {
98            serve_static(&req, &self.index).await
99        }
100    }
101}