1use std::marker::PhantomData;
2
3use axum::body::{boxed, Body, Full, HttpBody};
4use axum::handler::HandlerWithoutStateExt;
5use axum::http::{header, StatusCode, Uri};
6use axum::response::{IntoResponse, Response};
7use axum::routing::get_service;
8use axum::Router;
9use mime_guess::mime;
10use rust_embed::RustEmbed;
11
/// Embedded file served for the SPA root and for any request path that has
/// no matching embedded asset.
const INDEX_PATH: &str = "index.html";
13
14pub struct SpaRouter<A, B = Body, T = (), S = ()>
15where
16 A: RustEmbed,
17{
18 path: &'static str,
19 _assets: PhantomData<A>,
20 _marker: PhantomData<fn() -> (B, T, S)>,
21}
22
23impl<A, B, T, S> SpaRouter<A, B, T, S>
24where
25 A: RustEmbed + 'static,
26{
27 pub fn new(path: &'static str) -> Self {
28 Self {
29 path,
30 _assets: Default::default(),
31 _marker: Default::default(),
32 }
33 }
34}
35
36impl<A, B, T, S> From<SpaRouter<A, B, T, S>> for Router<S, B>
37where
38 B: HttpBody + Send + 'static,
39 T: 'static,
40 A: RustEmbed + 'static,
41 S: Clone + Send + Sync + 'static,
42{
43 fn from(spa: SpaRouter<A, B, T, S>) -> Self {
44 Router::new()
45 .nest_service(spa.path, get_service(assets_handler::<A>.into_service()))
46 .fallback_service(get_service(serve_index::<A>.into_service()))
47 }
48}
49async fn serve_asset<A: RustEmbed>(path: &str) -> Response {
50
51 if let Some(index) = A::get(path).or_else(|| A::get(INDEX_PATH)) {
52 let body = boxed(Full::from(index.data));
53 let mime = mime_guess::from_path(path).first_or(mime::TEXT_HTML);
54 let etag = base64::encode(index.metadata.sha256_hash());
55 Response::builder()
56 .header(header::CONTENT_TYPE, mime.as_ref())
57 .header(header::ETAG, etag)
58 .body(body)
59 .unwrap_or_else(|_| not_found())
60 } else {
61 not_found()
62 }
63}
64
65async fn assets_handler<A: RustEmbed>(uri: Uri) -> Response {
66 if uri.path() == "/" {
67 serve_index::<A>().await
68 } else {
69 let path = uri.path().trim_start_matches('/');
70 serve_asset::<A>(path).await
71 }
72}
73
74async fn serve_index<A: RustEmbed>() -> Response {
75 serve_asset::<A>(INDEX_PATH).await
76}
77
78fn not_found() -> Response {
79 (StatusCode::NOT_FOUND, "Not found").into_response()
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use axum::response::Redirect;
    use axum::routing::get;
    use axum_test_helper::{TestClient, TestResponse};

    #[derive(RustEmbed)]
    #[folder = "fixture/"]
    struct TestAssets;

    #[derive(RustEmbed)]
    #[folder = "fixture-coordinator/"]
    struct TestAssetsCoordinator;

    /// Asserts that `res` looks like the SPA index page: `200 OK`, an `ETag`
    /// header and a `text/html` content type.
    #[track_caller]
    fn assert_index_response(res: &TestResponse) {
        assert_eq!(res.status(), StatusCode::OK);
        assert!(res.headers().get(header::ETAG).is_some());
        assert_eq!(
            res.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/html"
        );
    }

    #[tokio::test]
    async fn rust_embed_as_file_provider() {
        let resp = serve_index::<TestAssets>().await;
        assert_eq!(200, resp.status())
    }

    #[tokio::test]
    async fn basic() {
        let app = Router::new()
            .route("/foo", get(|| async { "GET /foo" }))
            .merge(SpaRouter::new("/") as SpaRouter<TestAssets>);
        let client = TestClient::new(app);

        let res = client.get("/").send().await;
        assert_eq!(res.status(), StatusCode::OK);
        assert!(res.headers().get(header::ETAG).is_some());
        assert_eq!(
            res.headers().get(header::CONTENT_TYPE).unwrap().as_bytes(),
            b"text/html"
        );
        assert_eq!(res.text().await, "<h1>Hello, World!</h1>\n");

        let res = client.get("/some/random/path").send().await;
        assert_eq!(res.status(), StatusCode::OK);
        assert!(res.headers().get(header::ETAG).is_some());
        assert_eq!(res.text().await, "<h1>Hello, World!</h1>\n");

        let res = client.get("/assets/script.js").send().await;
        assert_eq!(res.status(), StatusCode::OK);
        assert!(res.headers().get(header::ETAG).is_some());
        assert_eq!(
            res.headers().get(header::CONTENT_TYPE).unwrap().as_bytes(),
            b"application/javascript"
        );
        assert_eq!(res.text().await, "console.log('hi')\n");

        let res = client.get("/foo").send().await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "GET /foo");

        let res = client.get("/assets/doesnt_exist").send().await;
        assert_eq!(res.status(), StatusCode::OK);
        assert!(res.headers().get(header::ETAG).is_some());
        assert_eq!(res.text().await, "<h1>Hello, World!</h1>\n");
    }

    #[tokio::test]
    async fn coordinator_routing() {
        let app = Router::new()
            .route("/api", get(|| async { "OK" }))
            .route("/", get(|| async { Redirect::permanent("/ui/") }))
            .merge(SpaRouter::new("/ui") as SpaRouter<TestAssetsCoordinator>);

        let client = TestClient::new(app);

        let res = client.get("/").send().await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);

        let res = client.get("/ui").send().await;
        assert_index_response(&res);

        let res = client.get("/ui/").send().await;
        assert_index_response(&res);

        let res = client.get("/ui/webui-6d04b57e86bad3ca.js").send().await;
        assert_eq!(res.status(), StatusCode::OK);
        assert!(res.headers().get(header::ETAG).is_some());
        assert_eq!(
            res.headers().get(header::CONTENT_TYPE).unwrap().as_bytes(),
            b"application/javascript"
        );

        let res = client.get("/ui/auth/sign_in").send().await;
        assert_index_response(&res);

        let res = client.get("/ui/img/logo_dark.png").send().await;
        assert_eq!(res.status(), StatusCode::OK);
        assert!(res.headers().get(header::ETAG).is_some());
        assert_eq!(
            res.headers().get(header::CONTENT_TYPE).unwrap().as_bytes(),
            b"image/png"
        );

        let res = client.get("/api").send().await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "OK");
    }
}