1
2use axum::Router;
11use axum::http::{HeaderValue, Method, StatusCode};
12use axum::response::{IntoResponse, Json};
13use tower_http::cors::{CorsLayer, AllowOrigin};
14use tower_http::compression::CompressionLayer;
15use tower_http::services::ServeDir;
16use std::collections::HashMap;
17use std::net::SocketAddr;
18use std::sync::Arc;
19use std::time::Instant;
20use parking_lot::RwLock;
21use tracing::info;
22
23use crate::router::AlunRouter;
24use crate::middleware as mw;
25use alun_core::{PluginManager, Result};
26use alun_core::api::{codes, Res};
27use alun_config::{AppConfig, ConfigManager};
28use crate::resources::*;
29
/// Startup switches for an [`App`], mostly driven by command-line flags.
#[derive(Clone, Debug)]
pub struct AppSettings {
    /// Directory that `--gen-config` writes default configuration into;
    /// `None` falls back to `"config"`.
    pub config_path: Option<String>,
    /// When set, `start` only generates the default configuration and returns.
    pub gen_config_only: bool,
    /// When set, `start` prints the effective configuration as TOML before serving.
    pub print_config: bool,
}
40
41pub struct App {
54 router: Option<AlunRouter>,
56 plugins: PluginManager,
58 settings: AppSettings,
60 config_mgr: Option<Arc<ConfigManager>>,
62 prefix: String,
64 rate_limit_store: Arc<RwLock<HashMap<String, mw::IpWindow>>>,
66 custom_middleware_hook: Option<Box<dyn FnOnce(Router) -> Router + Send>>,
68 startup_hook: Option<Box<dyn FnOnce() -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send>>,
70}
71
72impl App {
73 pub fn new() -> Result<Self> {
86 Self::from_config_dir("config")
87 }
88
89 pub fn from_config() -> Result<Self> {
91 Self::new()
92 }
93
94 pub fn from_config_dir(dir: &str) -> Result<Self> {
96 let cm = Arc::new(ConfigManager::load(Some(dir)));
97 Self::with_config_manager(cm)
98 }
99
100 pub fn with_config(cfg: AppConfig) -> Result<Self> {
102 let cm = ConfigManager {
103 static_config: cfg,
104 dynamic: parking_lot::RwLock::new(HashMap::new()),
105 };
106 Self::with_config_manager(Arc::new(cm))
107 }
108
109 pub fn with_config_manager(cm: Arc<ConfigManager>) -> Result<Self> {
113 let cfg = cm.get();
114 alun_log::init(&cfg.log);
115
116 let prefix = cfg.router.prefix.clone();
117
118 Ok(Self {
119 router: Some(AlunRouter::new()),
120 plugins: PluginManager::new(),
121 settings: AppSettings {
122 config_path: Some("config".into()),
123 gen_config_only: false,
124 print_config: false,
125 },
126 config_mgr: Some(cm),
127 prefix,
128 rate_limit_store: Arc::new(RwLock::new(HashMap::new())),
129 custom_middleware_hook: None,
130 startup_hook: None,
131 })
132 }
133
134 pub fn parse_cli(mut self) -> Self {
138 let (gen_config, print_config) = alun_config::env::parse_args();
139 self.settings.gen_config_only = gen_config;
140 self.settings.print_config = print_config;
141 self
142 }
143
144 pub fn get<H, T>(mut self, path: &str, handler: H) -> Self
148 where
149 H: axum::handler::Handler<T, ()>,
150 T: 'static,
151 {
152 if let Some(ref mut r) = self.router {
153 r.add_get(path, handler);
154 }
155 self
156 }
157
158 pub fn post<H, T>(mut self, path: &str, handler: H) -> Self
160 where
161 H: axum::handler::Handler<T, ()>,
162 T: 'static,
163 {
164 if let Some(ref mut r) = self.router {
165 r.add_post(path, handler);
166 }
167 self
168 }
169
170 pub fn put<H, T>(mut self, path: &str, handler: H) -> Self
172 where
173 H: axum::handler::Handler<T, ()>,
174 T: 'static,
175 {
176 if let Some(ref mut r) = self.router {
177 r.add_put(path, handler);
178 }
179 self
180 }
181
182 pub fn delete<H, T>(mut self, path: &str, handler: H) -> Self
184 where
185 H: axum::handler::Handler<T, ()>,
186 T: 'static,
187 {
188 if let Some(ref mut r) = self.router {
189 r.add_delete(path, handler);
190 }
191 self
192 }
193
194 pub fn route<H, T>(mut self, method: &str, path: &str, handler: H) -> Self
205 where
206 H: axum::handler::Handler<T, ()>,
207 T: 'static,
208 {
209 if let Some(ref mut r) = self.router {
210 r.add_route(method, path, handler);
211 }
212 self
213 }
214
215 pub fn group(mut self, prefix: &str, f: impl FnOnce(Self) -> Self) -> Self {
229 let sub = f(Self {
230 router: Some(AlunRouter::new()),
231 plugins: PluginManager::new(),
232 settings: AppSettings {
233 config_path: None,
234 gen_config_only: false,
235 print_config: false,
236 },
237 config_mgr: None,
238 prefix: String::new(),
239 rate_limit_store: Arc::new(RwLock::new(HashMap::new())),
240 custom_middleware_hook: None,
241 startup_hook: None,
242 });
243 if let (Some(ref mut r), Some(sub_r)) = (self.router.as_mut(), sub.router) {
244 r.merge(prefix, sub_r);
245 }
246 self
247 }
248
249 pub fn scan(mut self) -> Self {
265 for register in crate::ROUTES {
266 if let Some(ref mut r) = self.router {
267 register(r);
268 }
269 }
270 self
271 }
272
273 pub fn merge(mut self, prefix: &str, sub: AlunRouter) -> Self {
275 if let Some(ref mut r) = self.router {
276 r.merge(prefix, sub);
277 }
278 self
279 }
280
281 pub fn with_permission<H, T>(
294 mut self, method: &str, path: &str, handler: H, permission: &str,
295 ) -> Self
296 where
297 H: axum::handler::Handler<T, ()>,
298 T: 'static,
299 {
300 let perm = permission.to_string();
301 if let Some(ref mut r) = self.router {
302 let wrap = move |mr: axum::routing::MethodRouter<()>| {
303 mr.route_layer(mw::RequirePermissionLayer::any(vec![perm]))
304 };
305 match method.to_uppercase().as_str() {
306 "GET" => r.add_get_with_layer(path, handler, wrap),
307 "POST" => r.add_post_with_layer(path, handler, wrap),
308 "PUT" => r.add_put_with_layer(path, handler, wrap),
309 "DELETE" => r.add_delete_with_layer(path, handler, wrap),
310 _ => r.add_get_with_layer(path, handler, wrap),
311 };
312 }
313 self
314 }
315
316 pub fn with_role<H, T>(
327 mut self, method: &str, path: &str, handler: H, role: &str,
328 ) -> Self
329 where
330 H: axum::handler::Handler<T, ()>,
331 T: 'static,
332 {
333 let rl = role.to_string();
334 if let Some(ref mut r) = self.router {
335 let wrap = move |mr: axum::routing::MethodRouter<()>| {
336 mr.route_layer(mw::RequireRoleLayer::any(vec![rl]))
337 };
338 match method.to_uppercase().as_str() {
339 "GET" => r.add_get_with_layer(path, handler, wrap),
340 "POST" => r.add_post_with_layer(path, handler, wrap),
341 "PUT" => r.add_put_with_layer(path, handler, wrap),
342 "DELETE" => r.add_delete_with_layer(path, handler, wrap),
343 _ => r.add_get_with_layer(path, handler, wrap),
344 };
345 }
346 self
347 }
348
349 pub fn plugin<P: alun_core::Plugin + 'static>(mut self, plugin: P) -> Self {
355 self.plugins = self.plugins.add(plugin);
356 self
357 }
358
359 pub fn on_startup<F, Fut>(mut self, hook: F) -> Self
375 where
376 F: FnOnce() -> Fut + Send + 'static,
377 Fut: std::future::Future<Output = ()> + Send + 'static,
378 {
379 self.startup_hook = Some(Box::new(|| Box::pin(hook())));
380 self
381 }
382
383 pub fn with_middleware_hook<F>(mut self, hook: F) -> Self
399 where
400 F: FnOnce(Router) -> Router + Send + 'static,
401 {
402 self.custom_middleware_hook = Some(Box::new(hook));
403 self
404 }
405
406 pub async fn start(mut self) -> Result<()> {
408 let startup_start = Instant::now();
409
410 if self.settings.gen_config_only {
411 let dir = self.settings.config_path.as_deref().unwrap_or("config");
412 ConfigManager::generate_default(dir)?;
413 return Ok(());
414 }
415
416 if self.config_mgr.is_none() {
417 let _ = tracing_subscriber::fmt()
418 .with_env_filter(tracing_subscriber::EnvFilter::try_from_default_env()
419 .unwrap_or_else(|_| "alun=info".into()))
420 .try_init();
421 }
422
423 if self.settings.print_config {
424 if let Some(ref cm) = self.config_mgr {
425 if let Ok(toml_str) = toml::to_string_pretty(cm.get()) {
426 println!("{}", toml_str);
427 }
428 }
429 }
430
431 if let Some(ref cm) = self.config_mgr {
433 Self::init_global_resources(cm).await?;
434 }
435
436 if let Some(hook) = self.startup_hook.take() {
438 (hook)().await;
439 }
440
441 self.plugins.check_duplicate_names()
443 .map_err(alun_core::Error::Config)?;
444 self.plugins.start_all().await?;
445
446 let router = self.router.take().unwrap_or_default();
448 let mut axum_router: Router = router.into_axum();
449 axum_router = self.build_middleware_chain(axum_router);
450
451 if let Some(ref cm) = self.config_mgr {
453 let cfg = cm.get();
454 if cfg.static_files.enabled {
455 let static_path = cfg.static_files.path.clone();
456 std::fs::create_dir_all(&static_path).ok();
457 info!("静态文件服务就绪 path={}", static_path);
458 axum_router = axum_router.fallback_service(ServeDir::new(&static_path));
459 } else if cfg.router.not_found.enabled {
460 axum_router = axum_router.fallback(Self::handle_not_found);
461 }
462 }
463
464 if let Some(hook) = self.custom_middleware_hook.take() {
465 axum_router = hook(axum_router);
466 }
467
468 let bind_addr = self.config_mgr
469 .as_ref()
470 .map(|cm| cm.get().server.listen.clone())
471 .unwrap_or_else(|| "0.0.0.0:0".to_string());
472
473 let socket_addr = parse_addr(&bind_addr)?;
474 let display_addr = resolve_display_addr(socket_addr);
475 let app_name = self.config_mgr.as_ref().map(|cm| cm.get().app_name.as_str()).unwrap_or("Alun");
476 info!("{} 启动 -> http://{}", app_name, display_addr);
477 if let Some(cm) = &self.config_mgr {
478 info!(
479 " profile={}, request_id={} log={} cors={} compression={} rate_limit={} jwt_auth={} static_files={} not_found={}",
480 cm.get().profile,
481 cm.get().middleware.request_id,
482 cm.get().middleware.request_log,
483 cm.get().middleware.cors.enabled,
484 cm.get().middleware.compression.enabled,
485 cm.get().middleware.rate_limit.enabled,
486 cm.get().middleware.auth.enabled,
487 cm.get().static_files.enabled,
488 cm.get().router.not_found.enabled,
489 );
490 }
491
492 let startup_ms = startup_start.elapsed().as_millis();
493 info!("{} 启动完成, 耗时 {}ms", app_name, startup_ms);
494
495 let listener = tokio::net::TcpListener::bind(socket_addr).await?;
496 axum::serve(listener, axum_router.into_make_service_with_connect_info::<SocketAddr>())
497 .with_graceful_shutdown(shutdown_signal())
498 .await?;
499
500 self.plugins.stop_all().await;
501 Ok(())
502 }
503
504 pub async fn serve(self, addr: impl Into<String>) -> Result<()> {
506 let mut s = self;
507 let addr_str = addr.into();
508 if s.config_mgr.is_none() {
509 let default_cfg = AppConfig {
510 server: alun_config::ServerConfig {
511 listen: addr_str.clone(),
512 ..Default::default()
513 },
514 ..Default::default()
515 };
516 s.config_mgr = Some(Arc::new(ConfigManager {
517 static_config: default_cfg,
518 dynamic: parking_lot::RwLock::new(HashMap::new()),
519 }));
520 } else if let Some(ref cm) = s.config_mgr {
521 let mut cfg = cm.get().clone();
522 cfg.server.listen = addr_str.clone();
523 s.config_mgr = Some(Arc::new(ConfigManager {
524 static_config: cfg,
525 dynamic: parking_lot::RwLock::new(HashMap::new()),
526 }));
527 }
528 s.start().await
529 }
530
531 async fn init_global_resources(cm: &Arc<ConfigManager>) -> Result<()> {
533 set_config(cm.clone()).map_err(|e| alun_core::Error::Config(e.to_string()))?;
534 let cfg = cfg();
535
536 #[cfg(feature = "db")]
537 if cfg.database.enabled {
538 match alun_db::factory::create_db(&cfg.database).await {
539 Ok(db) => {
540 info!("数据库连接成功");
541
542 if cfg.database.migration.enabled && cfg.database.migration.auto_migrate {
543 let migrator = alun_db::migrate::Migrator::new(db.clone(), cfg.database.migration.clone());
544 match migrator.run().await {
545 Ok(records) => info!("数据库迁移完成: {:?}", records.iter().map(|r| &r.version).collect::<Vec<_>>()),
546 Err(e) => {
547 tracing::error!("数据库迁移失败: {}", e);
548 return Err(alun_core::Error::Config(format!("数据库迁移失败: {}", e)));
549 }
550 }
551 }
552
553 set_db(db).map_err(|e| alun_core::Error::Config(e.to_string()))?;
554 }
555 Err(e) => {
556 tracing::error!("数据库连接失败: {}", e);
557 return Err(alun_core::Error::Config(format!("数据库连接失败: {}", e)));
558 }
559 }
560 }
561
562 #[cfg(feature = "cache")]
563 if cfg.cache.r#type != "local" || cfg.cache.max_capacity > 0 {
564 match alun_cache::create_cache(&cfg.cache, &cfg.redis).await {
565 Ok(c) => {
566 set_cache(c).map_err(|e| alun_core::Error::Config(e.to_string()))?;
567 }
568 Err(e) => {
569 tracing::warn!("缓存初始化失败: {},将不使用缓存", e);
570 }
571 }
572 }
573
574 #[cfg(feature = "template")]
575 {
576 match alun_template::TemplateEngine::from_dir(&cfg.template.path) {
577 Ok(engine) => {
578 info!("模板引擎就绪 path={}", cfg.template.path);
579 set_template(engine).map_err(|e| alun_core::Error::Config(e.to_string()))?;
580 }
581 Err(e) => {
582 tracing::warn!("模板引擎初始化失败: {},将使用空引擎", e);
583 let _ = set_template(alun_template::TemplateEngine::new());
584 }
585 }
586 }
587
588 {
590 let upload_path = &cfg.upload.path;
591 std::fs::create_dir_all(upload_path).map_err(|e| {
592 alun_core::Error::Config(format!("创建上传目录失败 '{}': {}", upload_path, e))
593 })?;
594 set_upload_path(upload_path.clone())
595 .map_err(|e| alun_core::Error::Config(e.to_string()))?;
596 info!("上传目录就绪 path={} max_size_mb={}", upload_path, cfg.upload.max_size_mb);
597 }
598
599 {
601 let download_path = &cfg.download.path;
602 std::fs::create_dir_all(download_path).map_err(|e| {
603 alun_core::Error::Config(format!("创建下载目录失败 '{}': {}", download_path, e))
604 })?;
605 set_download_path(download_path.clone())
606 .map_err(|e| alun_core::Error::Config(e.to_string()))?;
607 info!("下载目录就绪 path={}", download_path);
608 }
609
610 Ok(())
611 }
612
613 async fn handle_not_found() -> impl IntoResponse {
617 let msg = cfg().router.not_found.message.clone();
618 (StatusCode::NOT_FOUND, Json(Res::<()>::fail(codes::NOT_FOUND, msg)))
619 }
620
621 fn build_middleware_chain(
623 &self,
624 mut router: Router,
625 ) -> Router {
626 if let Some(ref cm) = self.config_mgr {
627 let cfg = cm.get();
628
629 if cfg.middleware.security_headers.enabled {
631 router = router.layer(mw::SecurityHeadersLayer::new(
632 cfg.middleware.security_headers.clone(),
633 ));
634 }
635
636 if cfg.middleware.request_log {
637 let log_cfg = &cfg.middleware.request_log_config;
638 let prefix_excluded: Vec<String> = log_cfg.exclude_paths
639 .iter().map(|p| format!("{}{}", self.prefix, p)).collect();
640 let log_layer = mw::RequestLogLayer {
641 exclude_paths: prefix_excluded,
642 log_duration: log_cfg.log_duration,
643 };
644 router = router.layer(log_layer);
645 }
646
647 if cfg.middleware.request_id {
648 router = router.layer(mw::RequestIdLayer);
649 }
650
651 if cfg.middleware.cors.enabled {
652 let mut cors = CorsLayer::new();
653 if !cfg.middleware.cors.allow_origins.is_empty() {
654 let origins: Vec<HeaderValue> = cfg.middleware.cors.allow_origins
655 .iter().filter_map(|o| o.parse().ok()).collect();
656 cors = cors.allow_origin(AllowOrigin::list(origins));
657 } else {
658 cors = cors.allow_origin(AllowOrigin::any());
659 }
660 if !cfg.middleware.cors.allow_methods.is_empty() {
661 let methods: Vec<Method> = cfg.middleware.cors.allow_methods
662 .iter().filter_map(|m| m.parse().ok()).collect();
663 cors = cors.allow_methods(methods);
664 }
665 if !cfg.middleware.cors.allow_headers.is_empty() {
666 let headers: Vec<axum::http::HeaderName> = cfg.middleware.cors.allow_headers
667 .iter().filter_map(|h| h.parse().ok()).collect();
668 cors = cors.allow_headers(headers);
669 } else {
670 cors = cors.allow_headers(tower_http::cors::AllowHeaders::any());
671 }
672 if cfg.middleware.cors.allow_credentials {
673 cors = cors.allow_credentials(true);
674 }
675 cors = cors.max_age(std::time::Duration::from_secs(cfg.middleware.cors.max_age_secs));
676 router = router.layer(cors);
677 }
678
679 if cfg.middleware.compression.enabled {
680 router = router.layer(CompressionLayer::new().gzip(true));
681 }
682
683 if cfg.middleware.rate_limit.enabled {
684 let rl_layer = mw::RateLimitLayer {
685 requests_per_window: cfg.middleware.rate_limit.requests_per_window,
686 window_secs: cfg.middleware.rate_limit.window_secs,
687 store: self.rate_limit_store.clone(),
688 };
689 router = router.layer(rl_layer);
690 }
691
692 let mut perm_layer = mw::PermissionCheckLayer::from_config(&cfg.middleware.permission.rules);
696 perm_layer = perm_layer.with_macro_rules(&crate::PERMISSION_ROUTES);
697 if cfg.middleware.permission.enabled && perm_layer.has_rules() {
698 router = router.layer(perm_layer);
699 }
700
701 if cfg.middleware.auth.enabled && !cfg.middleware.auth.jwt_secret.is_empty() {
702 let mut ignore: Vec<String> = cfg.middleware.auth.ignore_paths
703 .iter().map(|p| format!("{}{}", self.prefix, p)).collect();
704
705 for def in crate::NO_AUTH_ROUTES {
707 let path_with_prefix = format!("{}{}", self.prefix, def.path);
708 if !ignore.contains(&path_with_prefix) {
709 ignore.push(path_with_prefix);
710 }
711 }
712
713 #[cfg(feature = "cache")]
714 let cache = try_cache().cloned();
715 let auth_layer = mw::AuthLayer {
716 jwt_secret: cfg.middleware.auth.jwt_secret.clone(),
717 ignore_paths: ignore,
718 #[cfg(feature = "cache")]
719 cache,
720 };
721 router = router.layer(auth_layer);
722 }
723 }
724 router
725 }
726}
727
728impl Default for App {
729 fn default() -> Self {
730 Self {
731 router: Some(AlunRouter::new()),
732 plugins: PluginManager::new(),
733 settings: AppSettings {
734 config_path: Some("config".into()),
735 gen_config_only: false,
736 print_config: false,
737 },
738 config_mgr: None,
739 prefix: String::new(),
740 rate_limit_store: Arc::new(RwLock::new(HashMap::new())),
741 custom_middleware_hook: None,
742 startup_hook: None,
743 }
744 }
745}
746
747fn parse_addr(addr: &str) -> alun_core::Result<SocketAddr> {
750 let addr = match addr {
751 a if a.starts_with(':') => format!("0.0.0.0{}", a),
752 a if !a.contains(':') => format!("0.0.0.0:{}", a),
753 a => a.to_string(),
754 };
755 addr.parse()
756 .map_err(|e| alun_core::Error::Config(format!("无效地址 '{}': {}", addr, e)))
757}
758
/// Returns a human-clickable form of the bound address for startup logs.
///
/// A wildcard bind (`0.0.0.0` / `[::]`) is shown as the loopback address of
/// the same family, because a v6-only `[::]` socket is not reachable via
/// `127.0.0.1`. Concrete addresses are shown unchanged.
fn resolve_display_addr(addr: SocketAddr) -> String {
    if !addr.ip().is_unspecified() {
        return addr.to_string();
    }
    let loopback = match addr {
        SocketAddr::V4(_) => SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, addr.port())),
        SocketAddr::V6(_) => SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, addr.port())),
    };
    loopback.to_string()
}
766
767async fn shutdown_signal() {
768 tokio::signal::ctrl_c().await.expect("Ctrl-C 注册失败");
769 info!("收到关闭信号,优雅退出中...");
770}
771