1use anyhow::{anyhow, Context, Result};
48use axum::{
49 body::{Bytes, HttpBody},
50 extract::Request,
51 http::HeaderValue,
52 response::Response,
53 routing::{get_service, Route},
54};
55#[cfg(feature = "openssl")]
56use axum_server::tls_openssl::OpenSSLConfig;
57#[cfg(feature = "rustls")]
58use axum_server::tls_rustls::RustlsConfig;
59use futures_util::future::BoxFuture;
60use http::{
61 header::{self},
62 StatusCode,
63};
64#[cfg(feature = "reverse-proxy")]
65use http::{Method, Uri};
66use log::{debug, error, warn};
67pub use rust_embed;
68use std::{
69 collections::HashMap,
70 convert::Infallible,
71 env::current_exe,
72 fs::{self, create_dir_all},
73 net::SocketAddr,
74 path::PathBuf,
75 sync::Arc,
76};
77use tower::{util::ServiceExt as TowerServiceExt, Layer, Service};
78use tower_http::{
79 services::{ServeDir, ServeFile},
80 set_header::SetResponseHeaderLayer,
81};
82
83pub use axum::*;
84pub mod auth;
85pub mod session;
86pub use axum::debug_handler;
87pub use axum_help::*;
88
89#[derive(Default)]
98pub struct SpaServer<T = ()>
99where
100 T: Clone + Send + Sync + 'static,
101{
102 static_path: Vec<(String, PathBuf)>,
103 port: u16,
104 router: Router,
105 data: Option<T>,
106 forward: Option<String>,
107 release_path: PathBuf,
108 extra_layer: Vec<Box<dyn FnOnce(Router) -> Router>>,
109 host_routers: HashMap<String, Router>,
110}
111
/// Fallback handler used when a reverse-proxy address is configured.
///
/// It forwards `GET` requests to `forward_addr`, for example a frontend dev server. The
/// request URI is kept, and the scheme defaults to `http`. The upstream status, HTTP
/// version, headers and body are relayed back to the client.
///
/// # Errors
/// Returns `405 Method Not Allowed` for any method other than `GET`. Failures while
/// building the upstream URL, sending the request or reading the upstream body are
/// propagated with `?`.
#[cfg(feature = "reverse-proxy")]
#[axum::debug_handler]
async fn forwarded_to_dev(
    Extension(forward_addr): Extension<String>,
    uri: Uri,
    method: Method,
) -> HttpResult<Response> {
    use http::uri::Scheme;

    if method != Method::GET {
        return Err(HttpError {
            message: "Method not allowed".to_string(),
            status_code: StatusCode::METHOD_NOT_ALLOWED,
        });
    }

    let client = reqwest::Client::builder().no_proxy().build()?;
    let mut parts = uri.into_parts();
    parts.authority = Some(forward_addr.parse()?);
    if parts.scheme.is_none() {
        parts.scheme = Some(Scheme::HTTP);
    }
    let url = Uri::from_parts(parts)?.to_string();

    debug!("forward url: {}", url);
    let upstream = client.get(url).send().await?;
    let status = upstream.status();
    let version = upstream.version();
    let headers = upstream.headers().clone();
    // The body of a received reqwest response is a stream, so it has to be read to
    // completion here. Converting to `http::Response` and calling `Body::as_bytes`
    // yields `None` for streamed bodies, which would silently drop the payload.
    let body = upstream.bytes().await?;

    let mut response = Response::new(axum::body::Body::from(body));
    *response.status_mut() = status;
    *response.version_mut() = version;
    *response.headers_mut() = headers;
    Ok(response)
}
146
/// Placeholder compiled when the `reverse-proxy` feature is off.
///
/// Without that feature there is no way to set a forward address, so the server
/// never installs this handler.
///
/// # Panics
/// Always panics if it is ever polled.
#[cfg(not(feature = "reverse-proxy"))]
async fn forwarded_to_dev() {
    const MSG: &str = "reverse-proxy not enabled, should never call forwarded_to_dev";
    unreachable!("{}", MSG)
}
151
152impl<T> SpaServer<T>
153where
154 T: Clone + Send + Sync + 'static,
155{
156 pub fn new() -> Result<Self> {
158 Ok(Self {
159 static_path: Vec::new(),
160 port: 8080,
161 forward: None,
162 release_path: current_exe()?
163 .parent()
164 .ok_or_else(|| anyhow!("no parent in current_exe"))?
165 .join(format!(".{}_static_files", env!("CARGO_PKG_NAME"))),
166 extra_layer: Vec::new(),
167 host_routers: HashMap::new(),
168 router: Router::new(),
169 data: None,
170 })
171 }
172
173 pub fn data(mut self, data: T) -> Self
177 where
178 T: Clone + Send + Sync + 'static,
179 {
180 self.data = Some(data);
181 self
182 }
183
184 pub fn layer<L, NewResBody>(mut self, layer: L) -> Self
188 where
189 L: Layer<Route> + Clone + Send + 'static,
190 L::Service: Service<Request, Response = Response<NewResBody>, Error = Infallible>
191 + Clone
192 + Send
193 + 'static,
194 <L::Service as Service<Request>>::Future: Send + 'static,
195 NewResBody: HttpBody<Data = Bytes> + Send + 'static,
196 NewResBody::Error: Into<BoxError>,
197 {
198 self.extra_layer.push(Box::new(move |app| app.layer(layer)));
199 self
200 }
201
202 #[cfg(feature = "reverse-proxy")]
206 #[cfg_attr(docsrs, doc(cfg(feature = "reverse-proxy")))]
207 pub fn reverse_proxy(mut self, addr: impl Into<String>) -> Self {
208 self.forward = Some(addr.into());
209 self
210 }
211
212 pub fn release_path(mut self, rp: impl Into<PathBuf>) -> Self {
216 self.release_path = rp.into();
217 self
218 }
219
220 pub async fn run<Root>(self, root: Root) -> Result<()>
222 where
223 Root: SpaStatic,
224 {
225 self.run_raw(Some(root), None).await
226 }
227
228 #[cfg(any(feature = "openssl", feature = "rustls"))]
230 pub async fn run_tls<Root>(self, root: Root, config: HttpsConfig) -> Result<()>
231 where
232 Root: SpaStatic,
233 {
234 self.run_raw(Some(root), Some(config)).await
235 }
236
237 pub async fn run_api(self) -> Result<()> {
239 self.run_raw::<ApiOnly>(None, None).await
240 }
241
242 #[cfg(any(feature = "openssl", feature = "rustls"))]
244 pub async fn run_api_tls(self, config: HttpsConfig) -> Result<()> {
245 self.run_raw::<ApiOnly>(None, Some(config)).await
246 }
247
248 async fn run_raw<Root>(mut self, root: Option<Root>, config: Option<HttpsConfig>) -> Result<()>
250 where
251 Root: SpaStatic,
252 {
253 if let Some(root) = root {
254 let embeded_dir = root.release(self.release_path)?;
255 let index_file = embeded_dir.clone().join("index.html");
256
257 self.router = if let Some(addr) = self.forward {
258 self.router
259 .fallback(forwarded_to_dev)
260 .layer(Extension(addr))
261 } else {
262 self.router.fallback_service(
263 get_service(ServeDir::new(&embeded_dir).fallback(ServeFile::new(index_file)))
264 .layer(Self::add_cache_control())
265 .handle_error(|e: anyhow::Error| async move {
266 (
267 StatusCode::INTERNAL_SERVER_ERROR,
268 format!(
269 "Unhandled internal server error {:?} when serve embeded path {}",
270 e,
271 embeded_dir.display()
272 ),
273 )
274 }),
275 )
276 };
277 }
278
279 for sf in self.static_path {
280 self.router = self.router.nest_service(
281 &sf.0,
282 get_service(ServeDir::new(&sf.1))
283 .layer(Self::add_cache_control())
284 .handle_error(|e: anyhow::Error| async move {
285 (
286 StatusCode::INTERNAL_SERVER_ERROR,
287 format!(
288 "Unhandled internal server error {:?} when serve static path {}",
289 e,
290 sf.1.display()
291 ),
292 )
293 }),
294 )
295 }
296
297 self.router = self
298 .router
299 .layer(MatchHostLayer::new(Arc::new(self.host_routers.clone())));
300
301 if let Some(data) = self.data {
302 self.router = self.router.layer(Extension(data));
303 }
304
305 for layer in self.extra_layer {
306 self.router = layer(self.router)
307 }
308
309 let addr = format!("0.0.0.0:{}", self.port).parse()?;
310 #[allow(unused_variables)]
311 if let Some(config) = config {
312 #[cfg(all(feature = "openssl", feature = "rustls"))]
313 compile_error!("Feature openssl and Feature rustls can not be enabled together");
314
315 #[cfg(any(feature = "openssl", feature = "rustls"))]
316 {
317 #[cfg(feature = "rustls")]
318 {
319 let certificate = std::fs::read(config.certificate)?;
320 let private_key = std::fs::read(config.private_key)?;
321 axum_server::bind_rustls(
322 addr,
323 RustlsConfig::from_pem(certificate, private_key).await?,
324 )
325 }
326 #[cfg(feature = "openssl")]
327 {
328 axum_server::bind_openssl(
329 addr,
330 OpenSSLConfig::from_pem_file(config.certificate, config.private_key)
331 .context("openssl load pem file error")?,
332 )
333 }
334 }
335 .serve(
336 self.router
337 .into_make_service_with_connect_info::<SocketAddr>(),
338 )
339 .await?;
340 } else {
341 axum_server::bind(addr)
342 .serve(
343 self.router
344 .into_make_service_with_connect_info::<SocketAddr>(),
345 )
346 .await
347 .context("serve server error")?;
348 }
349
350 Ok(())
351 }
352
353 pub fn route(mut self, path: impl AsRef<str>, router: Router) -> Self {
356 self.router = self.router.nest(path.as_ref(), router);
357 self
358 }
359
360 pub fn port(mut self, port: u16) -> Self {
363 self.port = port;
364 self
365 }
366
367 pub fn static_path(mut self, path: impl Into<String>, dir: impl Into<PathBuf>) -> Self {
371 self.static_path.push((path.into(), dir.into()));
372 self
373 }
374
375 pub fn host_router(mut self, host: impl Into<String>, router: Router) -> Self {
378 self.host_routers.insert(host.into(), router);
379 self
380 }
381
382 fn add_cache_control() -> SetResponseHeaderLayer<HeaderValue> {
383 SetResponseHeaderLayer::if_not_present(
384 header::CACHE_CONTROL,
385 HeaderValue::from_static("max-age=300"),
386 )
387 }
388}
389
390#[derive(Clone)]
391struct MatchHostLayer {
392 host_routers: Arc<HashMap<String, Router>>,
393}
394
395impl<S> Layer<S> for MatchHostLayer {
396 type Service = MatchHost<S>;
397
398 fn layer(&self, inner: S) -> Self::Service {
399 MatchHost {
400 inner,
401 host_routers: self.host_routers.clone(),
402 }
403 }
404}
405
406impl MatchHostLayer {
407 pub fn new(host_routers: Arc<HashMap<String, Router>>) -> Self {
408 Self { host_routers }
409 }
410}
411
412#[derive(Clone)]
413struct MatchHost<S> {
414 inner: S,
415 host_routers: Arc<HashMap<String, Router>>,
416}
417
418impl<S> Service<Request> for MatchHost<S>
419where
420 S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
421 S::Future: Send + 'static,
422 S::Error: Into<Infallible>,
423{
424 type Response = S::Response;
425 type Error = S::Error;
426 type Future = BoxFuture<'static, Result<S::Response, S::Error>>;
427
428 fn poll_ready(
429 &mut self,
430 cx: &mut std::task::Context<'_>,
431 ) -> std::task::Poll<Result<(), Self::Error>> {
432 self.inner.poll_ready(cx)
433 }
434
435 fn call(&mut self, req: Request) -> Self::Future {
436 let host_routers = self.host_routers.clone();
437 let mut srv = self.inner.clone();
438 Box::pin(async move {
439 let hostname = req
440 .headers()
441 .get(header::HOST)
442 .and_then(|h| h.to_str().ok())
443 .unwrap_or_default();
444
445 if let Some((_, router)) = host_routers.iter().find(|(k, _v)| hostname.ends_with(*k)) {
446 router.clone().oneshot(req).await
447 } else {
448 srv.call(req).await
449 }
450 })
451 }
452}
453
/// Locations of the PEM-encoded TLS certificate and private key used by the TLS
/// `run*` methods of `SpaServer`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpsConfig {
    /// Path to the PEM certificate file.
    pub certificate: PathBuf,
    /// Path to the PEM private key file.
    pub private_key: PathBuf,
}
458
/// Embeds TLS PEM files and turns them into an [`HttpsConfig`].
///
/// * `embed_https_pems!("path/to/pems")` declares a `HttpsPems` type that embeds
///   that folder. The folder is expected to contain `cert.pem` and `key.pem`.
/// * `embed_https_pems!()` writes those two files into a
///   `{CARGO_PKG_NAME}_{CARGO_PKG_VERSION}` directory under the system temp
///   directory. It evaluates to an `anyhow::Result<HttpsConfig>` pointing at them,
///   or an error if either file is missing or cannot be written.
#[macro_export]
macro_rules! embed_https_pems {
    ($path: literal) => {
        #[derive(spa_rs::rust_embed::RustEmbed)]
        #[crate_path = "spa_rs::rust_embed"]
        #[folder = $path]
        struct HttpsPems;
    };

    () => {{
        let https_config = || -> anyhow::Result<spa_rs::HttpsConfig> {
            let base_path = std::env::temp_dir().join(format!(
                "{}_{}",
                env!("CARGO_PKG_NAME"),
                env!("CARGO_PKG_VERSION")
            ));
            // Report a failure here directly instead of as a later, less clear write error.
            std::fs::create_dir_all(&base_path)?;

            // NOTE(review): files (including the private key) are written with the
            // platform's default permissions into a shared temp directory; consider
            // restricting access to the key file.
            let release = |name: &str| -> anyhow::Result<Option<std::path::PathBuf>> {
                match HttpsPems::get(name) {
                    Some(file) => {
                        let path = base_path.join(name);
                        std::fs::write(&path, &file.data)?;
                        Ok(Some(path))
                    }
                    None => Ok(None),
                }
            };

            match (release("cert.pem")?, release("key.pem")?) {
                (Some(certificate), Some(private_key)) => Ok(spa_rs::HttpsConfig {
                    certificate,
                    private_key,
                }),
                _ => anyhow::bail!("invalid ssl cert or key embed file"),
            }
        };
        https_config()
    }};
}
514
/// Declares or names the `StaticFiles` type that embeds a frontend build folder.
///
/// * `spa_server_root!("dist")` declares `StaticFiles`, embedding `dist` via
///   `rust_embed`, and implements `SpaStatic` for it.
/// * `spa_server_root!()` expands to the `StaticFiles` value, e.g. to pass it to
///   `SpaServer::run`.
#[macro_export]
macro_rules! spa_server_root {
    () => {
        StaticFiles
    };
    ($root:literal) => {
        #[derive(spa_rs::rust_embed::RustEmbed)]
        #[crate_path = "spa_rs::rust_embed"]
        #[folder = $root]
        struct StaticFiles;

        impl spa_rs::SpaStatic for StaticFiles {}
    };
}
531
532pub trait SpaStatic: rust_embed::RustEmbed {
535 fn release(&self, release_path: PathBuf) -> Result<PathBuf> {
536 let target_dir = release_path;
537 if !target_dir.exists() {
538 create_dir_all(&target_dir)?;
539 }
540
541 for file in Self::iter() {
542 match Self::get(&file) {
543 Some(f) => {
544 if let Some(p) = std::path::Path::new(file.as_ref()).parent() {
545 let parent_dir = target_dir.join(p);
546 create_dir_all(parent_dir)?;
547 }
548
549 let path = target_dir.join(file.as_ref());
550 debug!("release static file: {}", path.display());
551 if let Err(e) = fs::write(path, f.data) {
552 error!("static file {} write error: {:?}", file, e);
553 }
554 }
555 None => warn!("static file {} not found", file),
556 }
557 }
558
559 Ok(target_dir)
560 }
561}
562
563impl SpaStatic for ApiOnly {}
564impl rust_embed::RustEmbed for ApiOnly {
565 fn get(_file_path: &str) -> Option<rust_embed::EmbeddedFile> {
566 unreachable!()
567 }
568
569 fn iter() -> rust_embed::Filenames {
570 unreachable!()
571 }
572}
573
574struct ApiOnly;