use axum::Router;
use axum::http::{HeaderValue, Method, StatusCode};
use axum::response::{IntoResponse, Json};
use tower_http::cors::{CorsLayer, AllowOrigin};
use tower_http::compression::CompressionLayer;
use tower_http::services::ServeDir;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use parking_lot::RwLock;
use tracing::info;
use crate::router::AlunRouter;
use crate::middleware as mw;
use alun_core::{PluginManager, Result};
use alun_core::api::{codes, Res};
use alun_config::{AppConfig, ConfigManager};
use crate::resources::*;
/// Command-line driven switches that change what [`App::start`] does.
#[derive(Clone)]
pub struct AppSettings {
    /// Directory used when generating the default configuration files.
    /// `start` falls back to `"config"` when this is `None`.
    pub config_path: Option<String>,
    /// When `true`, `start` only generates the default configuration and returns.
    pub gen_config_only: bool,
    /// When `true`, `start` prints the loaded configuration as TOML to stdout.
    pub print_config: bool,
}
/// Boxed future produced by a startup hook.
type StartupFuture = std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>;
/// One-shot callback that `App::start` awaits before plugins are started.
type StartupHook = Box<dyn FnOnce() -> StartupFuture + Send>;
/// One-shot callback that receives the assembled axum router and returns the
/// router that will actually be served.
type RouterHook = Box<dyn FnOnce(Router) -> Router + Send>;

/// Builder-style application: collects routes, plugins and hooks, then serves
/// them through axum when `start` is called.
pub struct App {
    /// Route table; an `Option` so that `start` can take ownership of it.
    router: Option<AlunRouter>,
    /// Plugins started before serving and stopped afterwards.
    plugins: PluginManager,
    /// Switches filled in by `parse_cli`.
    settings: AppSettings,
    /// Loaded configuration; `None` when the app was built via `Default`.
    config_mgr: Option<Arc<ConfigManager>>,
    /// Route prefix copied from the configuration; prepended to the
    /// request-log exclusion paths and the auth ignore paths.
    prefix: String,
    /// Per-key window state handed to the rate-limit layer.
    rate_limit_store: Arc<RwLock<HashMap<String, mw::IpWindow>>>,
    /// Optional user hook applied to the router after built-in middleware.
    custom_middleware_hook: Option<RouterHook>,
    /// Optional user hook awaited during `start`.
    startup_hook: Option<StartupHook>,
}
impl App {
/// Builds an application from the configuration in the `config` directory.
///
/// # Errors
/// Propagates whatever [`App::from_config_dir`] returns.
pub fn new() -> Result<Self> {
    const DEFAULT_CONFIG_DIR: &str = "config";
    Self::from_config_dir(DEFAULT_CONFIG_DIR)
}
/// Alias of [`App::new`]: loads the configuration from the default directory.
///
/// # Errors
/// Propagates the error returned by [`App::new`].
pub fn from_config() -> Result<Self> {
    let app = Self::new()?;
    Ok(app)
}
/// Builds an application from the configuration stored in `dir`.
///
/// The configuration is loaded with `ConfigManager::load(Some(dir))` and the
/// directory is remembered in `settings.config_path`, so that a later
/// "generate default config" run writes into the same directory instead of
/// the hard-coded `config` default.
///
/// # Errors
/// Propagates the error returned by [`App::with_config_manager`].
pub fn from_config_dir(dir: &str) -> Result<Self> {
    let cm = Arc::new(ConfigManager::load(Some(dir)));
    let mut app = Self::with_config_manager(cm)?;
    // `with_config_manager` cannot know where the config came from and always
    // records "config"; fix it up with the directory actually used.
    app.settings.config_path = Some(dir.to_owned());
    Ok(app)
}
pub fn with_config(cfg: AppConfig) -> Result<Self> {
let cm = ConfigManager {
static_config: cfg,
dynamic: parking_lot::RwLock::new(HashMap::new()),
};
Self::with_config_manager(Arc::new(cm))
}
/// Builds an application around an existing configuration manager.
///
/// Initialises logging from the `log` section of the configuration and
/// copies the router prefix; every other field starts at its empty default.
/// `settings.config_path` is always set to `"config"` here.
pub fn with_config_manager(cm: Arc<ConfigManager>) -> Result<Self> {
    // Keep the borrow of `cm` inside this scope so it can be moved below.
    let prefix = {
        let loaded = cm.get();
        alun_log::init(&loaded.log);
        loaded.router.prefix.clone()
    };
    let settings = AppSettings {
        config_path: Some(String::from("config")),
        gen_config_only: false,
        print_config: false,
    };
    Ok(Self {
        router: Some(AlunRouter::new()),
        plugins: PluginManager::new(),
        settings,
        config_mgr: Some(cm),
        prefix,
        rate_limit_store: Arc::new(RwLock::new(HashMap::new())),
        custom_middleware_hook: None,
        startup_hook: None,
    })
}
/// Reads the command-line switches via `alun_config::env::parse_args` and
/// stores them in the settings.
///
/// The returned tuple is `(gen_config, print_config)`, in that order.
pub fn parse_cli(mut self) -> Self {
    let settings = &mut self.settings;
    (settings.gen_config_only, settings.print_config) = alun_config::env::parse_args();
    self
}
/// Registers `handler` at `path` by forwarding to `AlunRouter::add_get`.
///
/// Does nothing when the router is no longer present.
pub fn get<H, T>(mut self, path: &str, handler: H) -> Self
where
    H: axum::handler::Handler<T, ()>,
    T: 'static,
{
    match self.router.as_mut() {
        Some(routes) => {
            routes.add_get(path, handler);
        }
        None => {}
    }
    self
}
/// Registers `handler` at `path` by forwarding to `AlunRouter::add_post`.
///
/// Does nothing when the router is no longer present.
pub fn post<H, T>(mut self, path: &str, handler: H) -> Self
where
    H: axum::handler::Handler<T, ()>,
    T: 'static,
{
    match self.router.as_mut() {
        Some(routes) => {
            routes.add_post(path, handler);
        }
        None => {}
    }
    self
}
/// Registers `handler` at `path` by forwarding to `AlunRouter::add_put`.
///
/// Does nothing when the router is no longer present.
pub fn put<H, T>(mut self, path: &str, handler: H) -> Self
where
    H: axum::handler::Handler<T, ()>,
    T: 'static,
{
    if let Some(routes) = &mut self.router {
        routes.add_put(path, handler);
    }
    self
}
/// Registers `handler` at `path` by forwarding to `AlunRouter::add_delete`.
///
/// Does nothing when the router is no longer present.
pub fn delete<H, T>(mut self, path: &str, handler: H) -> Self
where
    H: axum::handler::Handler<T, ()>,
    T: 'static,
{
    if let Some(routes) = &mut self.router {
        routes.add_delete(path, handler);
    }
    self
}
/// Registers `handler` for the given `method` and `path` by forwarding to
/// `AlunRouter::add_route`; how `method` is interpreted is decided there.
///
/// Does nothing when the router is no longer present.
pub fn route<H, T>(mut self, method: &str, path: &str, handler: H) -> Self
where
    H: axum::handler::Handler<T, ()>,
    T: 'static,
{
    match self.router.as_mut() {
        Some(routes) => {
            routes.add_route(method, path, handler);
        }
        None => {}
    }
    self
}
pub fn group(mut self, prefix: &str, f: impl FnOnce(Self) -> Self) -> Self {
let sub = f(Self {
router: Some(AlunRouter::new()),
plugins: PluginManager::new(),
settings: AppSettings {
config_path: None,
gen_config_only: false,
print_config: false,
},
config_mgr: None,
prefix: String::new(),
rate_limit_store: Arc::new(RwLock::new(HashMap::new())),
custom_middleware_hook: None,
startup_hook: None,
});
if let (Some(ref mut r), Some(sub_r)) = (self.router.as_mut(), sub.router) {
r.merge(prefix, sub_r);
}
self
}
/// Applies every registration function listed in `crate::ROUTES` to the
/// router. Does nothing when the router is no longer present.
pub fn scan(mut self) -> Self {
    if let Some(routes) = self.router.as_mut() {
        for register in crate::ROUTES {
            register(&mut *routes);
        }
    }
    self
}
pub fn merge(mut self, prefix: &str, sub: AlunRouter) -> Self {
if let Some(ref mut r) = self.router {
r.merge(prefix, sub);
}
self
}
pub fn with_permission<H, T>(
mut self, method: &str, path: &str, handler: H, permission: &str,
) -> Self
where
H: axum::handler::Handler<T, ()>,
T: 'static,
{
let perm = permission.to_string();
if let Some(ref mut r) = self.router {
let wrap = move |mr: axum::routing::MethodRouter<()>| {
mr.route_layer(mw::RequirePermissionLayer::any(vec![perm]))
};
match method.to_uppercase().as_str() {
"GET" => r.add_get_with_layer(path, handler, wrap),
"POST" => r.add_post_with_layer(path, handler, wrap),
"PUT" => r.add_put_with_layer(path, handler, wrap),
"DELETE" => r.add_delete_with_layer(path, handler, wrap),
_ => r.add_get_with_layer(path, handler, wrap),
};
}
self
}
pub fn with_role<H, T>(
mut self, method: &str, path: &str, handler: H, role: &str,
) -> Self
where
H: axum::handler::Handler<T, ()>,
T: 'static,
{
let rl = role.to_string();
if let Some(ref mut r) = self.router {
let wrap = move |mr: axum::routing::MethodRouter<()>| {
mr.route_layer(mw::RequireRoleLayer::any(vec![rl]))
};
match method.to_uppercase().as_str() {
"GET" => r.add_get_with_layer(path, handler, wrap),
"POST" => r.add_post_with_layer(path, handler, wrap),
"PUT" => r.add_put_with_layer(path, handler, wrap),
"DELETE" => r.add_delete_with_layer(path, handler, wrap),
_ => r.add_get_with_layer(path, handler, wrap),
};
}
self
}
/// Adds `plugin` to the plugin manager; plugins are started by `start`.
pub fn plugin<P: alun_core::Plugin + 'static>(mut self, plugin: P) -> Self {
    let registry = self.plugins.add(plugin);
    self.plugins = registry;
    self
}
/// Installs an async hook that `start` awaits once, after global resources
/// are initialised and before plugins are started.
///
/// A later call replaces a previously installed hook.
pub fn on_startup<F, Fut>(mut self, hook: F) -> Self
where
    F: FnOnce() -> Fut + Send + 'static,
    Fut: std::future::Future<Output = ()> + Send + 'static,
{
    // Erase the concrete future type so the hook can be stored in the struct.
    let boxed: Box<
        dyn FnOnce() -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send,
    > = Box::new(move || Box::pin(hook()));
    self.startup_hook = Some(boxed);
    self
}
/// Installs a hook that receives the fully assembled axum router (built-in
/// middleware and fallback already applied) and returns the router to serve.
///
/// A later call replaces a previously installed hook.
pub fn with_middleware_hook<F>(mut self, hook: F) -> Self
where
    F: FnOnce(Router) -> Router + Send + 'static,
{
    let boxed: Box<dyn FnOnce(Router) -> Router + Send> = Box::new(hook);
    self.custom_middleware_hook = Some(boxed);
    self
}
pub async fn start(mut self) -> Result<()> {
let startup_start = Instant::now();
if self.settings.gen_config_only {
let dir = self.settings.config_path.as_deref().unwrap_or("config");
ConfigManager::generate_default(dir)?;
return Ok(());
}
if self.config_mgr.is_none() {
let _ = tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "alun=info".into()))
.try_init();
}
if self.settings.print_config {
if let Some(ref cm) = self.config_mgr {
if let Ok(toml_str) = toml::to_string_pretty(cm.get()) {
println!("{}", toml_str);
}
}
}
if let Some(ref cm) = self.config_mgr {
Self::init_global_resources(cm).await?;
}
if let Some(hook) = self.startup_hook.take() {
(hook)().await;
}
self.plugins.check_duplicate_names()
.map_err(alun_core::Error::Config)?;
self.plugins.start_all().await?;
let router = self.router.take().unwrap_or_default();
let mut axum_router: Router = router.into_axum();
axum_router = self.build_middleware_chain(axum_router);
if let Some(ref cm) = self.config_mgr {
let cfg = cm.get();
if cfg.static_files.enabled {
let static_path = cfg.static_files.path.clone();
std::fs::create_dir_all(&static_path).ok();
info!("静态文件服务就绪 path={}", static_path);
axum_router = axum_router.fallback_service(ServeDir::new(&static_path));
} else if cfg.router.not_found.enabled {
axum_router = axum_router.fallback(Self::handle_not_found);
}
}
if let Some(hook) = self.custom_middleware_hook.take() {
axum_router = hook(axum_router);
}
let bind_addr = self.config_mgr
.as_ref()
.map(|cm| cm.get().server.listen.clone())
.unwrap_or_else(|| "0.0.0.0:0".to_string());
let socket_addr = parse_addr(&bind_addr)?;
let display_addr = resolve_display_addr(socket_addr);
let app_name = self.config_mgr.as_ref().map(|cm| cm.get().app_name.as_str()).unwrap_or("Alun");
info!("{} 启动 -> http://{}", app_name, display_addr);
if let Some(cm) = &self.config_mgr {
info!(
" profile={}, request_id={} log={} cors={} compression={} rate_limit={} jwt_auth={} static_files={} not_found={}",
cm.get().profile,
cm.get().middleware.request_id,
cm.get().middleware.request_log,
cm.get().middleware.cors.enabled,
cm.get().middleware.compression.enabled,
cm.get().middleware.rate_limit.enabled,
cm.get().middleware.auth.enabled,
cm.get().static_files.enabled,
cm.get().router.not_found.enabled,
);
}
let startup_ms = startup_start.elapsed().as_millis();
info!("{} 启动完成, 耗时 {}ms", app_name, startup_ms);
let listener = tokio::net::TcpListener::bind(socket_addr).await?;
axum::serve(listener, axum_router.into_make_service_with_connect_info::<SocketAddr>())
.with_graceful_shutdown(shutdown_signal())
.await?;
self.plugins.stop_all().await;
Ok(())
}
pub async fn serve(self, addr: impl Into<String>) -> Result<()> {
let mut s = self;
let addr_str = addr.into();
if s.config_mgr.is_none() {
let default_cfg = AppConfig {
server: alun_config::ServerConfig {
listen: addr_str.clone(),
..Default::default()
},
..Default::default()
};
s.config_mgr = Some(Arc::new(ConfigManager {
static_config: default_cfg,
dynamic: parking_lot::RwLock::new(HashMap::new()),
}));
} else if let Some(ref cm) = s.config_mgr {
let mut cfg = cm.get().clone();
cfg.server.listen = addr_str.clone();
s.config_mgr = Some(Arc::new(ConfigManager {
static_config: cfg,
dynamic: parking_lot::RwLock::new(HashMap::new()),
}));
}
s.start().await
}
/// Registers the process-wide resources derived from the configuration.
///
/// In order: the configuration itself, the database (feature `db`), the
/// cache (feature `cache`), the template engine (feature `template`), and
/// the upload and download directories.
///
/// # Errors
/// Returns `Error::Config` when a `set_*` registration fails, when the
/// database connection or migration fails, or when the upload/download
/// directory cannot be created. Cache and template failures only log a warning.
async fn init_global_resources(cm: &Arc<ConfigManager>) -> Result<()> {
// Publish the configuration first; everything below reads it back through
// the `cfg()` accessor (presumably provided by `crate::resources`).
set_config(cm.clone()).map_err(|e| alun_core::Error::Config(e.to_string()))?;
// From here on `cfg` is the accessor's return value and shadows the function name.
let cfg = cfg();
// Database: connection and migration failures abort startup.
#[cfg(feature = "db")]
if cfg.database.enabled {
match alun_db::factory::create_db(&cfg.database).await {
Ok(db) => {
info!("数据库连接成功");
// Migrations run only when both the feature switch and auto-migrate are on.
if cfg.database.migration.enabled && cfg.database.migration.auto_migrate {
let migrator = alun_db::migrate::Migrator::new(db.clone(), cfg.database.migration.clone());
match migrator.run().await {
Ok(records) => info!("数据库迁移完成: {:?}", records.iter().map(|r| &r.version).collect::<Vec<_>>()),
Err(e) => {
tracing::error!("数据库迁移失败: {}", e);
return Err(alun_core::Error::Config(format!("数据库迁移失败: {}", e)));
}
}
}
// The handle is registered only after migrations have succeeded.
set_db(db).map_err(|e| alun_core::Error::Config(e.to_string()))?;
}
Err(e) => {
tracing::error!("数据库连接失败: {}", e);
return Err(alun_core::Error::Config(format!("数据库连接失败: {}", e)));
}
}
}
// Cache: skipped only for a "local" cache with max_capacity == 0.
// A creation failure is not fatal; the app runs without a cache.
#[cfg(feature = "cache")]
if cfg.cache.r#type != "local" || cfg.cache.max_capacity > 0 {
match alun_cache::create_cache(&cfg.cache, &cfg.redis).await {
Ok(c) => {
set_cache(c).map_err(|e| alun_core::Error::Config(e.to_string()))?;
}
Err(e) => {
tracing::warn!("缓存初始化失败: {},将不使用缓存", e);
}
}
}
// Templates: fall back to an empty engine when the directory cannot be loaded.
#[cfg(feature = "template")]
{
match alun_template::TemplateEngine::from_dir(&cfg.template.path) {
Ok(engine) => {
info!("模板引擎就绪 path={}", cfg.template.path);
set_template(engine).map_err(|e| alun_core::Error::Config(e.to_string()))?;
}
Err(e) => {
tracing::warn!("模板引擎初始化失败: {},将使用空引擎", e);
// Unlike the success path, a registration failure here is ignored.
let _ = set_template(alun_template::TemplateEngine::new());
}
}
}
// Upload directory: created if missing, then registered.
{
let upload_path = &cfg.upload.path;
std::fs::create_dir_all(upload_path).map_err(|e| {
alun_core::Error::Config(format!("创建上传目录失败 '{}': {}", upload_path, e))
})?;
set_upload_path(upload_path.clone())
.map_err(|e| alun_core::Error::Config(e.to_string()))?;
info!("上传目录就绪 path={} max_size_mb={}", upload_path, cfg.upload.max_size_mb);
}
// Download directory: created if missing, then registered.
{
let download_path = &cfg.download.path;
std::fs::create_dir_all(download_path).map_err(|e| {
alun_core::Error::Config(format!("创建下载目录失败 '{}': {}", download_path, e))
})?;
set_download_path(download_path.clone())
.map_err(|e| alun_core::Error::Config(e.to_string()))?;
info!("下载目录就绪 path={}", download_path);
}
Ok(())
}
/// Fallback handler: answers with HTTP 404 and a JSON failure body whose
/// message comes from `router.not_found.message` in the global configuration.
async fn handle_not_found() -> impl IntoResponse {
    let body = Res::<()>::fail(codes::NOT_FOUND, cfg().router.not_found.message.clone());
    (StatusCode::NOT_FOUND, Json(body))
}
fn build_middleware_chain(
&self,
mut router: Router,
) -> Router {
if let Some(ref cm) = self.config_mgr {
let cfg = cm.get();
if cfg.middleware.security_headers.enabled {
router = router.layer(mw::SecurityHeadersLayer::new(
cfg.middleware.security_headers.clone(),
));
}
if cfg.middleware.request_log {
let log_cfg = &cfg.middleware.request_log_config;
let prefix_excluded: Vec<String> = log_cfg.exclude_paths
.iter().map(|p| format!("{}{}", self.prefix, p)).collect();
let log_layer = mw::RequestLogLayer {
exclude_paths: prefix_excluded,
log_duration: log_cfg.log_duration,
};
router = router.layer(log_layer);
}
if cfg.middleware.request_id {
router = router.layer(mw::RequestIdLayer);
}
if cfg.middleware.cors.enabled {
let mut cors = CorsLayer::new();
if !cfg.middleware.cors.allow_origins.is_empty() {
let origins: Vec<HeaderValue> = cfg.middleware.cors.allow_origins
.iter().filter_map(|o| o.parse().ok()).collect();
cors = cors.allow_origin(AllowOrigin::list(origins));
} else {
cors = cors.allow_origin(AllowOrigin::any());
}
if !cfg.middleware.cors.allow_methods.is_empty() {
let methods: Vec<Method> = cfg.middleware.cors.allow_methods
.iter().filter_map(|m| m.parse().ok()).collect();
cors = cors.allow_methods(methods);
}
if !cfg.middleware.cors.allow_headers.is_empty() {
let headers: Vec<axum::http::HeaderName> = cfg.middleware.cors.allow_headers
.iter().filter_map(|h| h.parse().ok()).collect();
cors = cors.allow_headers(headers);
} else {
cors = cors.allow_headers(tower_http::cors::AllowHeaders::any());
}
if cfg.middleware.cors.allow_credentials {
cors = cors.allow_credentials(true);
}
cors = cors.max_age(std::time::Duration::from_secs(cfg.middleware.cors.max_age_secs));
router = router.layer(cors);
}
if cfg.middleware.compression.enabled {
router = router.layer(CompressionLayer::new().gzip(true));
}
if cfg.middleware.rate_limit.enabled {
let rl_layer = mw::RateLimitLayer {
requests_per_window: cfg.middleware.rate_limit.requests_per_window,
window_secs: cfg.middleware.rate_limit.window_secs,
store: self.rate_limit_store.clone(),
};
router = router.layer(rl_layer);
}
let mut perm_layer = mw::PermissionCheckLayer::from_config(&cfg.middleware.permission.rules);
perm_layer = perm_layer.with_macro_rules(&crate::PERMISSION_ROUTES);
if cfg.middleware.permission.enabled && perm_layer.has_rules() {
router = router.layer(perm_layer);
}
if cfg.middleware.auth.enabled && !cfg.middleware.auth.jwt_secret.is_empty() {
let mut ignore: Vec<String> = cfg.middleware.auth.ignore_paths
.iter().map(|p| format!("{}{}", self.prefix, p)).collect();
for def in crate::NO_AUTH_ROUTES {
let path_with_prefix = format!("{}{}", self.prefix, def.path);
if !ignore.contains(&path_with_prefix) {
ignore.push(path_with_prefix);
}
}
#[cfg(feature = "cache")]
let cache = try_cache().cloned();
let auth_layer = mw::AuthLayer {
jwt_secret: cfg.middleware.auth.jwt_secret.clone(),
ignore_paths: ignore,
#[cfg(feature = "cache")]
cache,
};
router = router.layer(auth_layer);
}
}
router
}
}
impl Default for App {
    /// Creates an application without a configuration manager: empty router,
    /// no plugins, no hooks, empty prefix, and `"config"` as the config path.
    fn default() -> Self {
        let settings = AppSettings {
            config_path: Some(String::from("config")),
            gen_config_only: false,
            print_config: false,
        };
        Self {
            router: Some(AlunRouter::new()),
            plugins: PluginManager::new(),
            settings,
            config_mgr: None,
            prefix: String::new(),
            rate_limit_store: Arc::new(RwLock::new(HashMap::new())),
            custom_middleware_hook: None,
            startup_hook: None,
        }
    }
}
/// Parses a listen address, accepting two shorthand forms.
///
/// * `":8080"` becomes `0.0.0.0:8080`
/// * `"8080"` (no colon) becomes `0.0.0.0:8080`
/// * anything else is parsed as written, e.g. `127.0.0.1:8080` or `[::1]:8080`
///
/// Surrounding whitespace is ignored.
///
/// # Errors
/// Returns `Error::Config` when the result is not a valid socket address.
/// The message quotes the caller's original input rather than the expanded
/// form, so a bare `127.0.0.1` is no longer reported as `0.0.0.0:127.0.0.1`.
fn parse_addr(addr: &str) -> alun_core::Result<SocketAddr> {
    let input = addr.trim();
    let normalized = if input.starts_with(':') {
        format!("0.0.0.0{}", input)
    } else if !input.contains(':') {
        format!("0.0.0.0:{}", input)
    } else {
        input.to_string()
    };
    normalized
        .parse()
        .map_err(|e| alun_core::Error::Config(format!("无效地址 '{}': {}", addr, e)))
}
/// Turns a bind address into something a user can open in a browser.
///
/// A wildcard bind address is replaced by the loopback address of the same
/// family: `0.0.0.0` becomes `127.0.0.1` and `::` becomes `::1`. Using the
/// IPv6 loopback for an IPv6 wildcard keeps the link valid on sockets that
/// only accept IPv6. Any other address is returned unchanged.
fn resolve_display_addr(addr: SocketAddr) -> String {
    match addr.ip() {
        std::net::IpAddr::V4(ip) if ip.is_unspecified() => {
            format!("127.0.0.1:{}", addr.port())
        }
        std::net::IpAddr::V6(ip) if ip.is_unspecified() => {
            format!("[::1]:{}", addr.port())
        }
        _ => addr.to_string(),
    }
}
async fn shutdown_signal() {
tokio::signal::ctrl_c().await.expect("Ctrl-C 注册失败");
info!("收到关闭信号,优雅退出中...");
}